[JIT][NNC] Add handling of strides to dynamic shape support. (#70464)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/70464
Add handling of strided input tensors to dynamic fusion. This is done with the same set of input striding specializations as https://github.com/pytorch/pytorch/pull/60684/:
```
S_ONE, // STRIDE_ONE: packed
S_CONT, // STRIDE_CONTIGUOUS: stride[i + 1] * sizes[i + 1]
S_TRAN_CONT, // STRIDE_TRANSPOSED_CONTIGUOUS: stride[i-1] * sizes[i-1]
S_AS_ARG, // STRIDE_AS_ARG: stride passed in as runtime value
```
and two additional whole-tensor specializations: a) contiguous and b) channels-last contiguous. Channels-last is a common case and worth optimizing for; additionally, tensors natively record whether they are contiguous or channels-last contiguous, so checking these two patterns is faster than inspecting strides dimension by dimension.
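For illustration, here is a rough Python sketch of the per-dimension summarization (it mirrors the `summarizeStrideDim` helper added below; `summarize_strides` itself is a hypothetical stand-in, and the two whole-tensor cases are checked before falling back to per-dimension descriptors):
```python
def summarize_strides(sizes, strides):
    # hypothetical sketch; the real logic lives in summarizeStrideDim /
    # summarizeInputStrides in symbolic_shape_runtime_fusion.cpp
    desc = []
    for dim in range(len(sizes)):
        if strides[dim] == 1:
            desc.append("S_ONE")
        elif dim + 1 < len(sizes) and strides[dim] == strides[dim + 1] * sizes[dim + 1]:
            desc.append("S_CONT")
        # transposed-contiguous depends on the prior dim, contiguous on the next
        # dim; avoid the mutual dependence by requiring the prior dim not be S_CONT
        elif dim > 0 and strides[dim] == strides[dim - 1] * sizes[dim - 1] and desc[dim - 1] != "S_CONT":
            desc.append("S_TRAN_CONT")
        else:
            desc.append("S_AS_ARG")
    return desc

# a (5, 10) view produced by transposing a contiguous (10, 5) tensor:
print(summarize_strides([5, 10], [1, 5]))   # ['S_ONE', 'S_TRAN_CONT']
# a (4, 5) slice of a (4, 5, 2) tensor taken at [:, :, 0], strides (10, 2):
print(summarize_strides([4, 5], [10, 2]))   # ['S_CONT', 'S_AS_ARG']
```
A fully contiguous or channels-last contiguous tensor never reaches this per-dimension path; it is summarized as the single TENSOR_CONT / TENSOR_CONT_CHANNELS_LAST descriptor instead.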
Output striding will be handled in a follow-up.
The striding descriptors are stored on both the TensorExprGroup node and the guard node, as a vector of strings, for debuggability and to take advantage of being able to store IValues as attributes on nodes.
As an example:
```
%8 : Double(10, 11, 12, 13, strides=[1716, 1, 143, 11], requires_grad=0, device=cpu) = prim::TensorExprGroup_0[symbolic_shape_inputs=[-37, -36, -35, -34], striding_inputs_desc=[["TENSOR_CONT_CHANNELS_LAST"]]](%x, %24, %23, %22, %21)
```
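The new `test_dynamic_shapes` test in `test_jit_fuser_te.py` exercises this end to end. A trimmed, approximate sketch of that flow, using the internal toggles added in this PR (not guaranteed to reproduce fusion verbatim outside the test harness, which also restores these flags afterwards):
```python
import torch
from torch.testing import FileCheck

torch._C._jit_set_texpr_dynamic_shape_enabled(True)
torch._C._debug_set_fusion_group_inlining(True)

def f(x):
    return torch.sigmoid(torch.tanh(x))

scripted = torch.jit.trace(f, (torch.randn(10, 10),), check_trace=False)
torch._C._jit_pass_erase_shape_information(scripted.graph)

# warm up, then run transposed inputs of several sizes; one dynamically shaped
# kernel behind a single TensorExprDynamicGuard should be reused for all of them
for n in (10, 10, 11, 12):
    scripted(torch.randn(n, n).transpose(0, 1))

g = torch.jit.last_executed_optimized_graph()
torch._C._jit_pass_inline(g)
torch._C._jit_pass_dce(g)
FileCheck().check("TensorExprDynamicGuard").run(g)
```
The full test sweeps several input generators (contiguous, transposed, sliced, channels-last) and uses `check_count(..., exactly=True)` to verify that exactly one guard is created per striding pattern.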
Test Plan: Imported from OSS
Reviewed By: navahgar
Differential Revision: D33458649
Pulled By: eellison
fbshipit-source-id: c42616d3c683d70f6258180d23d3841a31a6030d
diff --git a/aten/src/ATen/core/interned_strings.h b/aten/src/ATen/core/interned_strings.h
index ee5c5d4..bcdaf99 100644
--- a/aten/src/ATen/core/interned_strings.h
+++ b/aten/src/ATen/core/interned_strings.h
@@ -475,6 +475,8 @@
_(attr, axes) \
_(attr, axis) \
_(attr, symbolic_shape_inputs) \
+ _(attr, striding_inputs_desc) \
+ _(attr, striding_outputs_desc) \
_(attr, broadcast) \
_(attr, direction) \
_(attr, ends) \
diff --git a/aten/src/ATen/core/jit_type.h b/aten/src/ATen/core/jit_type.h
index 613a340..b910298 100644
--- a/aten/src/ATen/core/jit_type.h
+++ b/aten/src/ATen/core/jit_type.h
@@ -33,6 +33,26 @@
void standardizeVectorForUnion(std::vector<TypePtr>& reference, std::vector<TypePtr>* to_fill);
void standardizeVectorForUnion(std::vector<TypePtr>* to_flatten);
+inline bool is_contiguous_strides(
+ const IntArrayRef sizes,
+ const IntArrayRef strides) {
+ int n_dim = static_cast<int>(sizes.size());
+ if (n_dim == 0) {
+ return true;
+ }
+
+ if (strides[n_dim - 1] != 1) {
+ return false;
+ }
+
+ for (int i = n_dim - 2; i >= 0; i--) {
+ if (strides[i] != strides[i + 1] * sizes[i + 1]) {
+ return false;
+ }
+ }
+ return true;
+}
+
struct AnyType;
using AnyTypePtr = SingletonTypePtr<AnyType>;
// Any is the top of the type hierarchy, all other types are subtypes
@@ -639,6 +659,12 @@
return copy;
}
+ TensorTypePtr withStrides(VaryingShape<Stride> sstrides) const {
+ auto cloned = clone();
+ cloned->strides_ = sstrides;
+ return cloned;
+ }
+
TensorTypePtr withSizesStrides(
at::IntArrayRef sizes,
at::IntArrayRef strides) const {
diff --git a/aten/src/ATen/core/tensor_type.cpp b/aten/src/ATen/core/tensor_type.cpp
index a2262a7..a363657 100644
--- a/aten/src/ATen/core/tensor_type.cpp
+++ b/aten/src/ATen/core/tensor_type.cpp
@@ -3,27 +3,6 @@
namespace c10 {
-namespace {
-
-inline bool is_contiguous_strides(
- const IntArrayRef sizes,
- const IntArrayRef strides) {
- int n_dim = static_cast<int>(sizes.size());
-
- if (n_dim == 0 || strides[n_dim-1] != 1) {
- return false;
- }
-
- for (int i = n_dim - 2; i >= 0; i--) {
- if (strides[i] != strides[i+1] * sizes[i+1]) {
- return false;
- }
- }
- return true;
-}
-
-} // namespace
-
const TensorTypePtr& TensorType::get() {
static auto value = TensorType::create(
{}, {}, SymbolicShape(), VaryingShape<Stride>{}, {});
diff --git a/test/cpp/jit/test_misc.cpp b/test/cpp/jit/test_misc.cpp
index 27b9b91..0c837c0 100644
--- a/test/cpp/jit/test_misc.cpp
+++ b/test/cpp/jit/test_misc.cpp
@@ -2871,7 +2871,7 @@
bool fusable_on_device = torch::jit::tensorexpr::getTEMustUseLLVMOnCPU();
torch::jit::tensorexpr::getTEMustUseLLVMOnCPU() = false;
setTensorExprDynamicShapeFusionEnabled(true);
- FuseTensorExprs(graph);
+ FuseTensorExprs(graph, /*min_group_size*/2, /*add_composed_op*/true);
Code code(graph, "");
InterpreterState interpreter{code};
std::vector<IValue> stack = {a, b};
@@ -2880,7 +2880,6 @@
at::Tensor out1 = pop(stack).toTensor();
ASSERT_TRUE(at::allclose(ref1, out1));
ASSERT_TRUE(at::allclose(ref2, out2));
- graph->dump();
auto inp_1 = at::ones({4, 4}, TensorOptions(kCPU).dtype(at::kFloat));
auto inp_2 = at::ones({4, 4}, TensorOptions(kCPU).dtype(at::kFloat));
diff --git a/test/cpp/jit/test_shape_analysis.cpp b/test/cpp/jit/test_shape_analysis.cpp
index 65b7d73..8056a77 100644
--- a/test/cpp/jit/test_shape_analysis.cpp
+++ b/test/cpp/jit/test_shape_analysis.cpp
@@ -69,6 +69,7 @@
subgraph->inputs().at(1)->setType(y_type);
subgraph->inputs().at(2)->setType(z_type);
auto output = g->insertNode(g->create(prim::TensorExprGroup))->output();
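+  // output striding is now summarized as well, which requires the subgraph
+  // outputs to have complete sizes and strides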
+ subgraph->outputs().at(0)->setType(TensorType::create(at::rand({14, 5})));
output->node()->addInput(x_inp);
output->node()->addInput(y_inp);
output->node()->addInput(z_inp);
@@ -268,6 +269,7 @@
auto x_type = TensorType::create(at::rand({10, 5}));
x_inp->setType(x_type);
subgraph->inputs().at(0)->setType(x_type);
+ subgraph->outputs().at(0)->setType(x_type);
auto output = g->insertNode(g->create(prim::TensorExprGroup))->output();
output->node()->addInput(x_inp);
output->node()->g_(attr::Subgraph, subgraph);
diff --git a/test/cpp/tensorexpr/test_dynamic_shapes.cpp b/test/cpp/tensorexpr/test_dynamic_shapes.cpp
index 04ac715..5421cd5 100644
--- a/test/cpp/tensorexpr/test_dynamic_shapes.cpp
+++ b/test/cpp/tensorexpr/test_dynamic_shapes.cpp
@@ -4,6 +4,7 @@
#include <torch/csrc/jit/frontend/code_template.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/ir/irparser.h>
+#include <torch/csrc/jit/passes/symbolic_shape_runtime_fusion.h>
#include <torch/csrc/jit/tensorexpr/kernel.h>
#include <torch/csrc/jit/testing/file_check.h>
#include <torch/torch.h>
@@ -48,11 +49,19 @@
// %4 : Float(SS(-2), SS(-3)) = aten::erf(%3)
// return (%4)
+ std::vector<torch::jit::StrideInput> input_desc = {
+ torch::jit::StrideInput::TENSOR_CONT};
+ std::unordered_map<
+ const torch::jit::Value*,
+ std::vector<torch::jit::StrideInput>>
+ symbolic_strides;
+ symbolic_strides[x_inp] = input_desc;
std::vector<int64_t> symbolic_shape_inputs = c10::fmap(
x_sym_dims,
[](const c10::ShapeSymbol& shapeSym) { return shapeSym.value(); });
- TensorExprKernel kernel(graph, {}, symbolic_shape_inputs);
+ TensorExprKernel kernel(
+ graph, {}, symbolic_shape_inputs, false, symbolic_strides);
// Run with the same static dims as the one we initialized the graph with.
{
auto a = at::rand({10, 5}, at::TensorOptions(at::kCPU).dtype(at::kFloat));
@@ -125,7 +134,17 @@
x_sym_dims,
[](const c10::ShapeSymbol& shapeSym) { return shapeSym.value(); });
- TensorExprKernel kernel(graph, {}, symbolic_shape_inputs);
+ std::vector<torch::jit::StrideInput> input_desc = {
+ torch::jit::StrideInput::TENSOR_CONT};
+ std::unordered_map<
+ const torch::jit::Value*,
+ std::vector<torch::jit::StrideInput>>
+ symbolic_strides;
+ symbolic_strides[x_inp] = input_desc;
+ symbolic_strides[y_inp] = input_desc;
+
+ TensorExprKernel kernel(
+ graph, {}, symbolic_shape_inputs, false, symbolic_strides);
// Run with the same static dims as the one we initialized the graph with.
{
@@ -205,7 +224,17 @@
std::vector<int64_t> symbolic_shape_inputs(
{x_dim0_sym.value(), x_dim1_sym.value()});
- TensorExprKernel kernel(graph, {}, symbolic_shape_inputs);
+ std::vector<torch::jit::StrideInput> input_desc = {
+ torch::jit::StrideInput::TENSOR_CONT};
+ std::unordered_map<
+ const torch::jit::Value*,
+ std::vector<torch::jit::StrideInput>>
+ symbolic_strides;
+ symbolic_strides[x_inp] = input_desc;
+ symbolic_strides[y_inp] = input_desc;
+
+ TensorExprKernel kernel(
+ graph, {}, symbolic_shape_inputs, false, symbolic_strides);
// Run with the same static dims as the one we initialized the graph with.
{
@@ -276,7 +305,17 @@
std::vector<int64_t> symbolic_shape_inputs({x_dim1_sym.value()});
- TensorExprKernel kernel(graph, {}, symbolic_shape_inputs);
+ std::vector<torch::jit::StrideInput> input_desc = {
+ torch::jit::StrideInput::TENSOR_CONT};
+ std::unordered_map<
+ const torch::jit::Value*,
+ std::vector<torch::jit::StrideInput>>
+ symbolic_strides;
+ symbolic_strides[x_inp] = input_desc;
+ symbolic_strides[y_inp] = input_desc;
+
+ TensorExprKernel kernel(
+ graph, {}, symbolic_shape_inputs, false, symbolic_strides);
// Run with the same static dims as the one we initialized the graph with.
{
@@ -388,7 +427,18 @@
y_dim0_sym.value(),
cat_dim0_sym.value()});
- TensorExprKernel kernel(graph, {}, symbolic_shape_inputs);
+ std::vector<torch::jit::StrideInput> input_desc = {
+ torch::jit::StrideInput::TENSOR_CONT};
+ std::unordered_map<
+ const torch::jit::Value*,
+ std::vector<torch::jit::StrideInput>>
+ symbolic_strides;
+ symbolic_strides[x_inp] = input_desc;
+ symbolic_strides[y_inp] = input_desc;
+ symbolic_strides[z_inp] = input_desc;
+
+ TensorExprKernel kernel(
+ graph, {}, symbolic_shape_inputs, false, symbolic_strides);
auto a = at::rand({10, 5}, at::TensorOptions(at::kCPU).dtype(at::kFloat));
auto b = at::rand({4, 5}, at::TensorOptions(at::kCPU).dtype(at::kFloat));
diff --git a/test/test_jit_fuser_te.py b/test/test_jit_fuser_te.py
index af6f401..ff0881e 100644
--- a/test/test_jit_fuser_te.py
+++ b/test/test_jit_fuser_te.py
@@ -58,6 +58,15 @@
torch._C._jit_set_texpr_reductions_enabled(old)
@contextlib.contextmanager
+def texpr_dynamic_enabled():
+ old = torch._C._jit_texpr_dynamic_shape_enabled()
+ torch._C._jit_set_texpr_dynamic_shape_enabled(True)
+ try:
+ yield
+ finally:
+ torch._C._jit_set_texpr_dynamic_shape_enabled(old)
+
+@contextlib.contextmanager
def inline_fusion_groups():
old_inlining = torch._C._debug_get_fusion_group_inlining()
torch._C._debug_set_fusion_group_inlining(True)
@@ -2000,6 +2009,77 @@
test(*args)
self.assertIn("fused_mul_add", prof.table())
+ def test_dynamic_shapes(self):
+ from functools import partial
+ n = 10
+
+ gen_tensor = (
+ lambda n: R(1, n),
+ lambda n: R(n, n),
+ lambda n: R(n, n).transpose(0, 1),
+ lambda n: R(n + 1, n + 1, 2)[:n, n, 0],
+ lambda n: R(n, n, 2)[:, :, 0],
+ lambda n: R(n, n + 1, n + 2, n + 3).to(memory_format=torch.channels_last),
+ )
+
+ with texpr_dynamic_enabled():
+ def foo(x, y, z):
+ return torch.sigmoid(torch.tanh(x))
+
+ foo.__disable_jit_function_caching__ = True
+
+ def fi(x, y, z):
+ return torch.tanh(x + y)
+
+ fi.__disable_jit_function_caching__ = True
+
+ def fum(x, y, z):
+ return torch.tanh(x + y) + z
+
+ fum.__disable_jit_function_caching__ = True
+
+ funcs = [foo, fi, fum]
+ with inline_fusion_groups():
+ # TODO: cuda ir eval error
+ for device in ['cpu']:
+ I = partial(torch.randint, 0, 100, device=device)
+ R = partial(torch.randn, device=device)
+
+ for i, func in enumerate(funcs):
+ num_args = i + 1
+ for j, gen in enumerate(gen_tensor):
+ inps = (gen(n), gen(n), gen(n))
+ func_s = torch.jit.trace(func, inps, check_trace=False)
+ torch._C._jit_pass_erase_shape_information(func_s.graph)
+ for _ in range(2):
+ x, y, z = gen(n), gen(n), gen(n)
+ func_s(x, y, z)
+
+ for incr in range(3):
+ func_s(*[gen(n + 1) for _ in range(3)])
+
+ g = torch.jit.last_executed_optimized_graph()
+ torch._C._jit_pass_inline(g)
+ torch._C._jit_pass_dce(g)
+
+ # We should see only one optimized kernel
+ FileCheck().check_count("TensorExprDynamicGuard", 1, exactly=True).run(g)
+ self.assertEqual(func(*inps), func_s(*inps))
+
+ gen = gen_tensor[0]
+ inps = (gen(n), gen(n), gen(n))
+ foo_s = torch.jit.trace(foo, inps)
+ torch._C._jit_pass_erase_shape_information(foo_s.graph)
+ g_prev = None
+ for gen in gen_tensor:
+ for i in range(3):
+ foo_s(*[gen(n + i) for _ in range(3)])
+ inps = (gen(n), gen(n), gen(n))
+ self.assertEqual(foo_s(*inps), foo(*inps))
+ g = torch.jit.last_executed_optimized_graph()
+ torch._C._jit_pass_inline(g)
+ torch._C._jit_pass_dce(g)
+ FileCheck().check_count("TensorExprDynamicGuard", len(gen_tensor), exactly=True).run(g)
works_list = [
'__radd__',
diff --git a/torch/csrc/jit/passes/symbolic_shape_runtime_fusion.cpp b/torch/csrc/jit/passes/symbolic_shape_runtime_fusion.cpp
index 0ec2467..0f2c0b5 100644
--- a/torch/csrc/jit/passes/symbolic_shape_runtime_fusion.cpp
+++ b/torch/csrc/jit/passes/symbolic_shape_runtime_fusion.cpp
@@ -72,25 +72,127 @@
void insertDynamicShapesGuard(
const ShapeComputeGraphMapping& shape_mapping,
Node* guarded_node,
- bool add_composed_op);
+ bool add_composed_op,
+ std::vector<std::vector<StrideInput>>& input_info,
+ std::vector<StrideInput>& output_strides);
+
+std::string toString(StrideInput si) {
+ switch (si) {
+ case StrideInput::TENSOR_CONT:
+ return "TENSOR_CONT";
+ case StrideInput::TENSOR_CONT_CHANNELS_LAST:
+ return "TENSOR_CONT_CHANNELS_LAST";
+ case StrideInput::S_ONE:
+ return "S_ONE";
+ case StrideInput::S_CONT:
+ return "S_CONT";
+ case StrideInput::S_TRAN_CONT:
+ return "S_TRAN_CONT";
+ case StrideInput::S_AS_ARG:
+ return "S_AS_ARG";
+ }
+ TORCH_INTERNAL_ASSERT(false);
+}
+
+StrideInput strideInputFromString(const std::string& si) {
+ if (si == "TENSOR_CONT") {
+ return StrideInput::TENSOR_CONT;
+ } else if (si == "TENSOR_CONT_CHANNELS_LAST") {
+ return StrideInput::TENSOR_CONT_CHANNELS_LAST;
+ } else if (si == "S_ONE") {
+ return StrideInput::S_ONE;
+ } else if (si == "S_CONT") {
+ return StrideInput::S_CONT;
+ } else if (si == "S_TRAN_CONT") {
+ return StrideInput::S_TRAN_CONT;
+ } else if (si == "S_AS_ARG") {
+ return StrideInput::S_AS_ARG;
+ } else {
+ TORCH_INTERNAL_ASSERT(false);
+ }
+}
+
+// in the runtime guard, strides are serialized as one flat
+// vector. stride_inputs_offset indexes into that vector
+// where the strides of this tensor begin
+inline StrideInput summarizeStrideDim(
+ const c10::IntArrayRef sizes,
+ const c10::IntArrayRef strides,
+ size_t dim,
+ const std::vector<StrideInput>& stride_inputs,
+ size_t stride_inputs_offset) {
+ if (strides[dim] == 1) {
+ return StrideInput::S_ONE;
+ } else if (
+ dim + 1 < sizes.size() &&
+ strides[dim] == strides[dim + 1] * sizes[dim + 1]) {
+ return StrideInput::S_CONT;
+ // Transposed Contiguous depends on the prior dim and Contiguous depends on
+ // the next dim; to avoid a mutual dependence, only summarize this dim as
+ // Transposed Contiguous if the prior dim was not summarized as Contiguous
+ } else if (
+ dim > 0 && strides[dim] == strides[dim - 1] * sizes[dim - 1] &&
+ (stride_inputs[dim - 1 + stride_inputs_offset] != StrideInput::S_CONT)) {
+ return StrideInput::S_TRAN_CONT;
+ } else {
+ return StrideInput::S_AS_ARG;
+ }
+}
+
+std::vector<StrideInput> summarizeInputStrides(const TensorType& tt) {
+ auto strides = *tt.strides().concrete_sizes();
+ auto sizes = *tt.sizes().concrete_sizes();
+ if (c10::is_contiguous_strides(sizes, strides)) {
+ return {StrideInput::TENSOR_CONT};
+ // TODO: channels last 3d
+ } else if (c10::is_channels_last_strides_2d(sizes, strides)) {
+ return {StrideInput::TENSOR_CONT_CHANNELS_LAST};
+ }
+ std::vector<StrideInput> stride_inputs;
+ for (size_t dim = 0; dim < sizes.size(); ++dim) {
+ stride_inputs.push_back(
+ summarizeStrideDim(sizes, strides, dim, stride_inputs, 0));
+ }
+ return stride_inputs;
+}
+
+// TODO: incorporate output striding into codegen
+StrideInput summarizeOutputStrides(const TensorType& tt) {
+ auto strides = *tt.strides().concrete_sizes();
+ auto sizes = *tt.sizes().concrete_sizes();
+ // We only try to maintain output striding for channels last tensors,
+ // otherwise we defer to contiguous
+ // TODO: channels last 3d
+ if (c10::is_channels_last_strides_2d(sizes, strides)) {
+ return StrideInput::TENSOR_CONT_CHANNELS_LAST;
+ }
+ return StrideInput::TENSOR_CONT;
+}
// Generalize Complete Shapes inputs to Symbolic Shapes.
// Dimensions of value 1 will be preserved, otherwise
// dimensions with the same value will be bucketed to the same
// symbolic shape.
// E.g. Tensor(5, 3), Tensor(3, 1) -> Tensor(SS(-1), SS(-2)), Tensor(SS(-2), 1)
-bool TryGeneralizeInputDimensionsToSymbolicShapes(
+// Also summarize input striding behavior. The size information is stored on
+// the type; the striding is returned. See StrideInput for a description of stride
+// specializations
+c10::optional<std::vector<std::vector<StrideInput>>>
+TryGeneralizeInputDimensionsToSymbolicShapes(
std::shared_ptr<Graph> tensorexpr_graph) {
std::map<size_t, int64_t> shape_to_sym_shape;
+ std::vector<std::vector<StrideInput>> input_striding;
+
for (Value* v : tensorexpr_graph->inputs()) {
if (!v->type()->cast<TensorType>()) {
continue;
}
- if (!v->type()->expect<TensorType>()->sizes().isComplete()) {
- return false;
+ auto tt = v->type()->expectRef<TensorType>();
+ if (!tt.sizes().isComplete() || !tt.strides().isComplete()) {
+ return c10::nullopt;
}
- auto tt = v->type()->expect<TensorType>();
- std::vector<at::ShapeSymbol> shape_vec = *tt->symbolic_sizes().sizes();
+ input_striding.push_back(summarizeInputStrides(tt));
+ std::vector<at::ShapeSymbol> shape_vec = *tt.symbolic_sizes().sizes();
auto new_sizes = c10::fmap(shape_vec, [&](const at::ShapeSymbol& shape) {
auto value = shape.value();
TORCH_INTERNAL_ASSERT(value >= 0, "Expected complete tensor");
@@ -104,9 +206,9 @@
return new_shape_symbol;
}
});
- v->setType(tt->withSymbolicShapes(c10::SymbolicShape(new_sizes)));
+ v->setType(tt.withSymbolicShapes(c10::SymbolicShape(new_sizes)));
}
- return true;
+ return input_striding;
}
void moveConstantTensorsOutOfSubgraph(
@@ -162,10 +264,25 @@
moveConstantTensorsOutOfSubgraph(tensorexpr_graph_node, tensorexpr_graph);
// Generalize Inputs
- if (!TryGeneralizeInputDimensionsToSymbolicShapes(tensorexpr_graph)) {
+ auto input_striding =
+ TryGeneralizeInputDimensionsToSymbolicShapes(tensorexpr_graph);
+ if (!input_striding) {
return false;
}
+ // Get output striding behavior
+ std::vector<StrideInput> output_striding;
+ for (Value* v : tensorexpr_graph->outputs()) {
+ if (!v->type()->cast<TensorType>()) {
+ continue;
+ }
+ auto tt = v->type()->expectRef<TensorType>();
+ if (!tt.sizes().isComplete() || !tt.strides().isComplete()) {
+ return false;
+ }
+ output_striding.push_back(summarizeOutputStrides(tt));
+ }
+
// Try To Propagate Shapes
auto maybe_shape_compute_mapping =
PropagateShapesAndBuildLargeShapeComputeGraph(
@@ -178,7 +295,11 @@
// Insert Guard
insertDynamicShapesGuard(
- *maybe_shape_compute_mapping, tensorexpr_graph_node, add_composed_op);
+ *maybe_shape_compute_mapping,
+ tensorexpr_graph_node,
+ add_composed_op,
+ *input_striding,
+ output_striding);
return true;
}
@@ -219,7 +340,9 @@
void insertDynamicShapesGuard(
const ShapeComputeGraphMapping& shape_mapping,
Node* guarded_node,
- bool add_composed_op) {
+ bool add_composed_op,
+ std::vector<std::vector<StrideInput>>& input_info,
+ std::vector<StrideInput>& output_strides) {
GRAPH_DEBUG(
"Inserting a prim::TensorExprDynamicGuard guard for a node",
*guarded_node);
@@ -236,7 +359,8 @@
}
inputs_to_check.push_back(node_input);
guard_types.push_back(
- subgraph->inputs().at(i)->type()->expect<TensorType>());
+ subgraph->inputs().at(i)->type()->expect<TensorType>()->withStrides(
+ c10::VaryingShape<c10::Stride>()));
}
TORCH_INTERNAL_ASSERT(inputs_to_check.size());
@@ -307,6 +431,32 @@
}
guarded_node->is_(attr::symbolic_shape_inputs, symbolic_shape_inputs);
+ std::vector<std::vector<std::string>> input_striding;
+ for (auto& vec : input_info) {
+ auto string_info =
+ fmap(vec, [&](StrideInput inp) { return toString(inp); });
+ input_striding.push_back(string_info);
+ }
+ auto ival = IValue(input_striding);
+ guarded_node->ival_(attr::striding_inputs_desc, ival);
+ typecheck_node->ival_(attr::striding_inputs_desc, ival);
+
+ for (Value* v : subgraph->inputs()) {
+ if (auto t = v->type()->cast<TensorType>()) {
+ v->setType(t->withStrides(c10::VaryingShape<c10::Stride>()));
+ }
+ }
+ for (Value* v : subgraph->outputs()) {
+ if (auto t = v->type()->cast<TensorType>()) {
+ v->setType(t->withStrides(c10::VaryingShape<c10::Stride>()));
+ }
+ }
+
+ std::vector<std::string> output_striding =
+ fmap(output_strides, [&](StrideInput inp) { return toString(inp); });
+ auto output_ival = IValue(output_striding);
+ guarded_node->ival_(attr::striding_outputs_desc, output_ival);
+
if (add_composed_op) {
// Create a TensorExprDynamicGroup node
auto te_dyn_group = SubgraphUtils::createSingletonSubgraph(
@@ -401,6 +551,18 @@
TORCH_INTERNAL_ASSERT(maybe_device);
auto device = *maybe_device;
+ // flattened vector of each inputs striding behavior
+ std::vector<StrideInput> flattened_input_striding;
+ const IValue& sym_strides = node->ival(attr::striding_inputs_desc);
+ std::vector<std::vector<std::string>> sym_strides_strs =
+ sym_strides.to<std::vector<std::vector<std::string>>>();
+ for (const auto& vec : sym_strides_strs) {
+ for (const std::string& str : vec) {
+ flattened_input_striding.push_back(strideInputFromString(str));
+ }
+ }
+
for (auto type : types) {
auto tt = type->expect<TensorType>();
auto ss = tt->symbolic_sizes();
@@ -431,6 +593,7 @@
}
}
}
+
const auto num_inputs = types.size();
const auto num_symbolic_dims = sym_dim_flat_index.size();
return [num_inputs,
@@ -438,6 +601,7 @@
device,
expected_scalar_types,
flattened_input_dims,
+ flattened_input_striding,
num_symbolic_dims](Stack& stack) {
at::ArrayRef<IValue> inputs = last(stack, num_inputs);
drop(stack, num_inputs);
@@ -447,8 +611,10 @@
// each invocation or would that mess up with multithreaded
// inference since we are writing to it?
// TODO - smallvector here ?
+
std::vector<int64_t> flattened_symbolic_dims(num_symbolic_dims, -1);
size_t flattened_dim_offset = 0;
+ size_t flattened_stride_offset = 0;
for (const auto i : c10::irange(num_inputs)) {
at::Tensor tensor = inputs[i].toTensor();
if (C10_UNLIKELY(
@@ -458,18 +624,51 @@
push(stack, false);
return;
}
- // TODO: striding
- if (C10_UNLIKELY(
- !tensor.is_contiguous(at::MemoryFormat::Contiguous))) {
- push(stack, false);
- return;
- }
const auto& sizes = tensor.sizes();
const auto num_dims = sizes.size();
if (C10_UNLIKELY(num_dims != expected_dims[i])) {
push(stack, false);
return;
}
+ auto striding = flattened_input_striding[flattened_stride_offset];
+ // Tensors natively store whether they are contiguous
+ // in the default memory format or in channels last,
+ // so it is more efficient to query whether they follow this
+ // property than iterating over dimensions and checking yourself
+ if (striding == StrideInput::TENSOR_CONT) {
+ if (C10_UNLIKELY(
+ !tensor.is_contiguous(at::MemoryFormat::Contiguous))) {
+ push(stack, false);
+ return;
+ }
+ flattened_stride_offset += 1;
+ } else if (striding == StrideInput::TENSOR_CONT_CHANNELS_LAST) {
+ // TODO: 5D channels last
+ if (C10_UNLIKELY(!tensor.is_contiguous(
+ at::MemoryFormat::ChannelsLast))) {
+ push(stack, false);
+ return;
+ }
+ flattened_stride_offset += 1;
+ } else {
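+ // otherwise re-summarize each dimension of the runtime tensor and
+ // compare it against the expected per-dimension descriptors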
+ auto strides = tensor.strides();
+ for (size_t dim = 0; dim < num_dims; ++dim) {
+ auto summarized_dim = summarizeStrideDim(
+ sizes,
+ strides,
+ dim,
+ flattened_input_striding,
+ flattened_stride_offset);
+ if (C10_UNLIKELY(
+ summarized_dim !=
+ flattened_input_striding
+ [dim + flattened_stride_offset])) {
+ push(stack, false);
+ return;
+ }
+ }
+ flattened_stride_offset += num_dims;
+ }
for (const auto dim_index : c10::irange(num_dims)) {
const int64_t dim_value =
flattened_input_dims[dim_index + flattened_dim_offset];
@@ -520,6 +719,7 @@
// should be reusing Code and InterpreterState across calls to this op.
// But that is resulting in a "No frames found" error.
// TODO: Improve the performance of this by figuring out a better approach.
+ // NB: this is only run in Static Runtime (SR), which is single-threaded
return [code](Stack& stack) {
runTensorExprDynamicGroup(code, stack);
return 0;
diff --git a/torch/csrc/jit/passes/symbolic_shape_runtime_fusion.h b/torch/csrc/jit/passes/symbolic_shape_runtime_fusion.h
index ca03df1..0c88e11 100644
--- a/torch/csrc/jit/passes/symbolic_shape_runtime_fusion.h
+++ b/torch/csrc/jit/passes/symbolic_shape_runtime_fusion.h
@@ -3,6 +3,7 @@
#include <torch/csrc/Export.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/passes/symbolic_shape_analysis.h>
+
#include <unordered_map>
namespace torch {
@@ -31,5 +32,24 @@
TORCH_API void runTensorExprDynamicGroup(const Code& code, Stack& stack);
+enum class StrideInput {
+ // Tensors natively store whether they are contiguous or not as a property;
+ // this makes it faster to query `is_contiguous` or
+ // `is_contiguous(memory_format=channels_last)`
+ // than to loop over the sizes/strides yourself.
+ // For tensors with these properties, we only store one value:
+ TENSOR_CONT,
+ TENSOR_CONT_CHANNELS_LAST,
+ // The remaining cases describe striding with one stride enum
+ // per dimension:
+ S_ONE, // STRIDE_ONE: packed
+ S_CONT, // STRIDE_CONTIGUOUS: stride[i + 1] * sizes[i + 1]
+ S_TRAN_CONT, // STRIDE_TRANSPOSED_CONTIGUOUS: stride[i-1] * sizes[i-1]
+ S_AS_ARG, // STRIDE_AS_ARG: stride passed in as runtime value
+};
+
+TORCH_API std::string toString(StrideInput si);
+TORCH_API StrideInput strideInputFromString(const std::string& si);
+
} // namespace jit
} // namespace torch
diff --git a/torch/csrc/jit/passes/tensorexpr_fuser.cpp b/torch/csrc/jit/passes/tensorexpr_fuser.cpp
index c6b917d..c77ca49 100644
--- a/torch/csrc/jit/passes/tensorexpr_fuser.cpp
+++ b/torch/csrc/jit/passes/tensorexpr_fuser.cpp
@@ -361,8 +361,13 @@
class TensorExprFuser {
public:
- TensorExprFuser(std::shared_ptr<Graph> graph, size_t min_group_size)
- : graph_(std::move(graph)), min_group_size_(min_group_size) {
+ TensorExprFuser(
+ std::shared_ptr<Graph> graph,
+ size_t min_group_size,
+ bool add_composed_op)
+ : graph_(std::move(graph)),
+ min_group_size_(min_group_size),
+ add_composed_op_(add_composed_op) {
parseTENotFuseOption();
}
@@ -1166,7 +1171,7 @@
}
for (Node* fusion_group : fusion_groups) {
VLOG(1) << "GenerateGuard for fusion group: " << *fusion_group;
- if (!GenerateGuard(fusion_group, /*add_composed_op=*/true)) {
+ if (!GenerateGuard(fusion_group, add_composed_op_)) {
VLOG(1) << " Unfusing the fusion group because GenerateGuard failed"
<< std::endl;
SubgraphUtils::unmergeSubgraph(fusion_group);
@@ -1202,9 +1207,14 @@
std::set<NodeKind> operators_not_to_fuse;
// Minimal size of a fusion group
size_t min_group_size_;
+ // compose Runtime Type Guard and Kernel in one op
+ bool add_composed_op_;
};
-void FuseTensorExprs(std::shared_ptr<Graph>& graph, size_t min_group_size) {
+void FuseTensorExprs(
+ std::shared_ptr<Graph>& graph,
+ size_t min_group_size,
+ bool add_composed_op) {
GRAPH_DUMP("Before TExprFuser: ", graph);
// Temporary change for Block code generation.
@@ -1215,7 +1225,7 @@
// Get rid of dead code so that we don't waste effort fusing it.
EliminateDeadCode(graph);
- TensorExprFuser fuser(graph, min_group_size);
+ TensorExprFuser fuser(graph, min_group_size, add_composed_op);
fuser.run();
EliminateCommonSubexpression(graph);
@@ -1241,11 +1251,42 @@
if (node->hasAttribute(attr::symbolic_shape_inputs)) {
sym_shapes = node->is(attr::symbolic_shape_inputs);
}
+
std::unordered_map<c10::Symbol, tensorexpr::NNCLoweringFunction>
custom_lowerings;
auto subgraph = node->g(attr::Subgraph);
- auto kernel = std::make_shared<tensorexpr::TensorExprKernel>(
- subgraph, custom_lowerings, sym_shapes);
+ IValue sym_strides = node->ival(attr::striding_inputs_desc);
+
+ // The striding descriptors are serialized on the node as a vector of vectors
+ // of strings; translate them back to StrideInput enums
+ std::vector<std::vector<std::string>> sym_strides_strs =
+ sym_strides.to<std::vector<std::vector<std::string>>>();
+ std::vector<std::vector<StrideInput>> striding_inputs;
+ for (const auto& vec : sym_strides_strs) {
+ std::vector<StrideInput> input_desc;
+ input_desc.reserve(vec.size());
+ for (const std::string& str : vec) {
+ input_desc.push_back(strideInputFromString(str));
+ }
+ striding_inputs.push_back(input_desc);
+ }
+ std::unordered_map<const Value*, std::vector<StrideInput>> stride_map;
+ size_t index = 0;
+ for (Value* v : subgraph->inputs()) {
+ if (!v->type()->cast<TensorType>()) {
+ continue;
+ }
+ stride_map[v] = striding_inputs[index];
+ index++;
+ }
+
+ std::shared_ptr<tensorexpr::TensorExprKernel> kernel =
+ std::make_shared<tensorexpr::TensorExprKernel>(
+ subgraph,
+ custom_lowerings,
+ sym_shapes,
+ /*pre_alloc*/ false,
+ stride_map);
auto num_subgraph_inputs = subgraph->inputs().size();
return [kernel, num_subgraph_inputs](Stack& stack) {
diff --git a/torch/csrc/jit/passes/tensorexpr_fuser.h b/torch/csrc/jit/passes/tensorexpr_fuser.h
index 0f917e3..01fedc4 100644
--- a/torch/csrc/jit/passes/tensorexpr_fuser.h
+++ b/torch/csrc/jit/passes/tensorexpr_fuser.h
@@ -10,9 +10,13 @@
struct Graph;
// Run TensorExpressions-based fuser.
+// If add_composed_op is true, creates a single operation that performs both
+// the runtime check that input types match the guard and the dispatch to the
+// compiled kernel or the unoptimized fallback graph
TORCH_API void FuseTensorExprs(
std::shared_ptr<Graph>& graph,
- size_t min_group_size = 2);
+ size_t min_group_size = 2,
+ bool add_composed_op = false);
TORCH_API void setTensorExprFuserEnabled(bool val);
TORCH_API bool tensorExprFuserEnabled();
diff --git a/torch/csrc/jit/python/init.cpp b/torch/csrc/jit/python/init.cpp
index 7eb4fde..2bf9b82 100644
--- a/torch/csrc/jit/python/init.cpp
+++ b/torch/csrc/jit/python/init.cpp
@@ -826,6 +826,12 @@
.def("_jit_texpr_fallback_allowed", &tensorexpr::fallbackAllowed)
.def("_jit_texpr_set_fallback_allowed", &tensorexpr::setFallbackAllowed)
.def("_jit_set_texpr_reductions_enabled", &setTexprReductionsEnabled)
+ .def(
+ "_jit_set_texpr_dynamic_shape_enabled",
+ &setTensorExprDynamicShapeFusionEnabled)
+ .def(
+ "_jit_texpr_dynamic_shape_enabled",
+ &tensorExprDynamicShapeFusionEnabled)
.def("_jit_texpr_reductions_enabled", &texprReductionsEnabled)
.def(
"_jit_set_te_generate_block_code",
diff --git a/torch/csrc/jit/runtime/static/fusion.cpp b/torch/csrc/jit/runtime/static/fusion.cpp
index 6080cce..6252577 100644
--- a/torch/csrc/jit/runtime/static/fusion.cpp
+++ b/torch/csrc/jit/runtime/static/fusion.cpp
@@ -330,7 +330,10 @@
GRAPH_DEBUG("Graph before tracing: ", graph);
auto traced_graph = TraceGraph(graph, sample_inputs);
GRAPH_DEBUG("Graph after tracing: ", traced_graph);
- FuseTensorExprs(traced_graph);
+ FuseTensorExprs(
+ traced_graph,
+ /*min_group_size*/ 2,
+ /*add_composed_op*/ true);
graph->block()->clear();
graph->block()->cloneFrom(traced_graph->block(), nullptr);
GRAPH_DUMP("Graph after fusion: ", graph);
diff --git a/torch/csrc/jit/tensorexpr/codegen.h b/torch/csrc/jit/tensorexpr/codegen.h
index 8b20412..3115527 100644
--- a/torch/csrc/jit/tensorexpr/codegen.h
+++ b/torch/csrc/jit/tensorexpr/codegen.h
@@ -118,6 +118,7 @@
BufferArg(Tensor tensor) : buf_(tensor.buf()) {}
BufferArg(const VarHandle& var) : var_(var.node()), isVar_(true) {}
BufferArg(const BufHandle& buf) : buf_(buf.node()) {}
+ BufferArg(const BufPtr& buf) : buf_(buf) {}
VarPtr var() const {
return isVar_ ? var_ : buf_->base_handle();
diff --git a/torch/csrc/jit/tensorexpr/kernel.cpp b/torch/csrc/jit/tensorexpr/kernel.cpp
index 81ca890..04f8d89 100644
--- a/torch/csrc/jit/tensorexpr/kernel.cpp
+++ b/torch/csrc/jit/tensorexpr/kernel.cpp
@@ -8,7 +8,9 @@
#include <c10/util/irange.h>
#include <c10/util/string_utils.h>
#include <torch/csrc/jit/jit_log.h>
+#include <torch/csrc/jit/passes/symbolic_shape_runtime_fusion.h>
#include <torch/csrc/jit/tensorexpr/analysis.h>
+#include <torch/csrc/jit/tensorexpr/expr.h>
#include <torch/csrc/jit/tensorexpr/graph_opt.h>
#include <torch/csrc/jit/tensorexpr/ir_printer.h>
#include <torch/csrc/jit/tensorexpr/ir_simplifier.h>
@@ -903,102 +905,159 @@
return dims;
}
-std::vector<ExprHandle> TensorExprKernel::computeInputTensorDims(
- const torch::jit::Value* input) {
- auto tt = input->type()->expect<TensorType>();
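+// Returns the VarHandle representing the stride of tensor input
+// `tensor_input_index` at dimension `stride_index` when that stride is passed
+// in as a runtime argument (S_AS_ARG), creating and caching it on first use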
+ExprHandle TensorExprKernel::getStrideArg(
+ size_t tensor_input_index,
+ size_t stride_index) {
+ auto it = strideArgToVar_.find(
+ std::pair<size_t, size_t>(tensor_input_index, stride_index));
+ if (it == strideArgToVar_.end()) {
+ VarHandle var(
+ "stride_arg" + std::to_string(tensor_input_index) + "_" +
+ std::to_string(stride_index),
+ kLong);
+ strideArgToVar_[std::pair<size_t, size_t>(
+ tensor_input_index, stride_index)] = var;
+ return std::move(var);
+ }
+ return it->second;
+}
- // First case - static shapes
+std::vector<ExprHandle> TensorExprKernel::getInputStrides(
+ const torch::jit::Value* input,
+ const std::vector<ExprHandle>& inputTensorDims) {
+ std::vector<ExprHandle> inputTensorStrides;
if (input->isCompleteTensor()) {
- if (isContiguous(input)) {
- return toExprHandles(*tt->sizes().concrete_sizes());
+ auto const strides =
+ input->type()->expect<TensorType>()->strides().concrete_sizes();
+ for (size_t stride : *strides) {
+ inputTensorStrides.push_back(LongImm::make(stride));
}
+ return inputTensorStrides;
+ }
- // Non-contiguous tensors are represented as 1-d buffers in NNC
- ExprHandle flat_size = 1;
- for (size_t i = 0; i < *tt->sizes().size(); i++) {
- auto size = *tt->sizes()[i];
- if (size == 0) {
- flat_size = 0;
- break;
+ size_t rank = inputTensorDims.size();
+ TORCH_INTERNAL_ASSERT(symbolic_strides_.count(input));
+ std::vector<StrideInput>& stride_input = symbolic_strides_[input];
+ if (stride_input.size() == 1 &&
+ (stride_input[0] == StrideInput::TENSOR_CONT_CHANNELS_LAST ||
+ stride_input[0] == StrideInput::TENSOR_CONT)) {
+ auto strides = stride_input[0] == StrideInput::TENSOR_CONT
+ ? make_contiguous_strides(inputTensorDims)
+ : make_channels_last_strides(inputTensorDims);
+ return fmap(strides, [&](ExprPtr stride) { return ExprHandle(stride); });
+ }
+
+ inputTensorStrides.resize(rank);
+ std::vector<bool> stride_set;
+ for (size_t i = 0; i < rank; ++i) {
+ stride_set.push_back(false);
+ }
+ // first, generate non-dependent values
+ size_t generated_strides = 0;
+ for (const auto i : c10::irange(rank)) {
+ if (stride_input[i] == torch::jit::StrideInput::S_ONE) {
+ inputTensorStrides[i] = LongImm::make(1);
+ stride_set[i] = true;
+ generated_strides++;
+ } else if (stride_input[i] == torch::jit::StrideInput::S_AS_ARG) {
+ size_t input_index = input->offset();
+ inputTensorStrides[i] = getStrideArg(input_index, i);
+ stride_set[i] = true;
+ generated_strides++;
+ }
+ }
+ // Contiguous and Transposed Contiguous depend on adjacent values
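+ // so resolve them iteratively: each pass fills in any stride whose
+ // neighboring stride has already been set, until every stride is generated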
+ while (generated_strides != rank) {
+ for (int i = static_cast<int>(rank) - 1; i >= 0; i--) {
+ if (stride_input[i] == torch::jit::StrideInput::S_CONT &&
+ stride_set[i + 1]) {
+ inputTensorStrides[i] =
+ inputTensorStrides[i + 1] * inputTensorDims[i + 1];
+
+ stride_set[i] = true;
+ generated_strides++;
}
- flat_size = flat_size + (size - 1) * *tt->strides()[i];
}
- flat_size = IRSimplifier::simplify(flat_size);
- return {flat_size};
+ for (int i = 0; i < rank; i++) {
+ if (stride_input[i] == torch::jit::StrideInput::S_TRAN_CONT &&
+ stride_set[i - 1]) {
+ inputTensorStrides[i] =
+ inputTensorStrides[i - 1] * inputTensorDims[i - 1];
+ stride_set[i] = true;
+ generated_strides++;
+ }
+ }
}
-
- // Second case - symbolic shapes
- // We only handle symbolic shape input tensors that are contiguous.
- // TODO: Handle strided tensors with symbolic shapes.
- auto const& symbolicShape = tt->symbolic_sizes();
- auto rank = symbolicShape.rank();
- if (!rank) {
- throw std::runtime_error("Symbolic shapes must have static ranks.");
- }
- std::vector<ExprHandle> inputTensorDims;
- for (const auto i : c10::irange(*rank)) {
- inputTensorDims.emplace_back(getVarForShape(symbolicShape[i]));
- }
- return inputTensorDims;
+ return inputTensorStrides;
}
Tensor TensorExprKernel::bindInput(const torch::jit::Value* input) {
auto const& t = input->type();
auto const& outputs = input->owningGraph()->outputs();
std::unordered_set<const Value*> outputs_set(outputs.begin(), outputs.end());
+
Tensor result(nullptr, nullptr);
switch (t->kind()) {
case TypeKind::TensorType: {
auto tt = input->type()->cast<TensorType>();
- auto inputTensorDims = computeInputTensorDims(input);
-
- BufHandle inBuffer(
- "t" + input_name_map_[input],
- inputTensorDims,
- ToDtype(static_cast<ScalarType>(*tt->scalarType())));
- bufferArgs_.emplace_back(inBuffer);
+ bool contiguous_concrete_tensor =
+ (input->isCompleteTensor() && isContiguous(input));
+ bool contiguous_strided_tensor = symbolic_strides_.count(input) &&
+ symbolic_strides_[input].size() == 1 &&
+ symbolic_strides_[input][0] == torch::jit::StrideInput::TENSOR_CONT;
// We don't need to copy the input if:
// 1) it is not an output AND
// 2) it is contiguous
- //
- // For static shapes we can check (2) directly, for symbolic shapes we
- // currently *assume* it is contiguous.
- //
- // TODO: update this logic as soon as we start supporting symbolic
- // strides.
- if (!outputs_set.count(input) &&
- (!input->isCompleteTensor() || isContiguous(input))) {
+ bool contiguous = contiguous_concrete_tensor || contiguous_strided_tensor;
+ if (!outputs_set.count(input) && contiguous) {
+ BufHandle inBuffer(
+ "t" + input_name_map_[input],
+ sizesFromSymbolicShape(tt->symbolic_sizes()),
+ ToDtype(static_cast<ScalarType>(*tt->scalarType())));
bufs_.emplace(input, inBuffer.node());
+ bufferArgs_.emplace_back(inBuffer);
break;
}
- // Symbolic shapes case:
- if (!input->isCompleteTensor()) {
- TORCH_INTERNAL_ASSERT(outputs_set.count(input));
- result = Compute(
- "input" + c10::to_string(bufs_.size() + 1),
- c10::fmap<DimArg>(inputTensorDims),
- [&](const std::vector<VarHandle>& axes) {
- return inBuffer.load(axes);
- });
- bufs_.emplace(input, result.buf());
- break;
+ // if the input isn't contiguous or is an output,
+ // write strided input into contiguous buffer that is
+ // then used in all further compute
+ std::vector<DimArg> inputTensorDims;
+ auto size_handles = sizesFromSymbolicShape(tt->symbolic_sizes());
+ for (size_t i = 0; i < size_handles.size(); i++) {
+ auto size = size_handles[i];
+ inputTensorDims.emplace_back(DimArg(size, "i" + c10::to_string(i)));
}
+ auto inputTensorStrides = getInputStrides(input, size_handles);
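+ // the strided input is modeled as a flat 1-D buffer whose length is the
+ // index of its last element plus one: 1 + sum_i (size_i - 1) * stride_i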
+ ExprHandle flat_size = 1;
+ for (size_t i = 0; i < size_handles.size(); ++i) {
+ auto size = size_handles[i];
+ if (size.AsNode<LongImm>() && immediateAs<int64_t>(size.node()) == 0) {
+ flat_size = 0;
+ break;
+ }
+ flat_size = flat_size + (size - 1) * inputTensorStrides[i];
+ }
+ flat_size = IRSimplifier::simplify(flat_size);
+ BufHandle inBuffer(
+ "t" + input_name_map_[input],
+ {flat_size},
+ ToDtype(static_cast<ScalarType>(*tt->scalarType())));
- // Static shapes, non-contiguous case:
- auto const strides = tt->strides();
result = Compute(
"input" + c10::to_string(bufs_.size() + 1),
- c10::fmap<DimArg>(*tt->sizes().concrete_sizes()),
+ inputTensorDims,
[&](const std::vector<VarHandle>& axes) {
ExprHandle idx = 0;
for (size_t i = 0; i < axes.size(); i++) {
- idx = idx + axes[i] * *strides[i];
+ idx = idx + axes[i] * inputTensorStrides[i];
}
return inBuffer.load(idx);
});
bufs_.emplace(input, result.buf());
+ bufferArgs_.emplace_back(inBuffer);
break;
}
case TypeKind::FloatType: {
@@ -1213,6 +1272,8 @@
BlockPtr TensorExprKernel::bindAllInputs() {
std::vector<CodeGen::BufferArg> symbolic_shape_args;
+ std::vector<CodeGen::BufferArg> symbolic_stride_args;
+
auto symbolic_shape_inputs_start_pos =
nInputs_ - symbolic_shape_inputs_.size();
if (has_symbolic_shapes_) {
@@ -1231,6 +1292,7 @@
// their symbolic sizes needs to be associated with these variables we
// create for the symbolic input params.
symbolic_shape_args.reserve(symbolic_shape_inputs_.size());
+
for (size_t i = symbolic_shape_inputs_start_pos; i < nInputs_; ++i) {
auto input = graph_->inputs()[i];
if (input->type()->kind() != TypeKind::IntType) {
@@ -1247,6 +1309,25 @@
shapeSymbolToVar_[symbolic_shape_inputs_[i]] =
scalars_[graph_->inputs()[symbolic_shape_inputs_start_pos + i]];
}
+
+ // Next, process the tensor inputs and create a codegen argument for every
+ // stride that is passed in as a runtime value (S_AS_ARG)
+ for (size_t i = 0; i < symbolic_shape_inputs_start_pos; ++i) {
+ auto input = graph_->inputs()[i];
+ auto tt = input->type()->cast<TensorType>();
+ if (!tt) {
+ continue;
+ }
+ TORCH_INTERNAL_ASSERT(symbolic_strides_.count(input));
+ auto symbolic_stride = symbolic_strides_[input];
+ for (size_t j = 0; j < symbolic_stride.size(); ++j) {
+ if (symbolic_stride[j] == torch::jit::StrideInput::S_AS_ARG) {
+ VarHandle v("v" + input_name_map_[input], kLong);
+ symbolic_stride_args.emplace_back(v);
+ strideArgToVar_[{i, j}] = v;
+ input_stride_args_.emplace_back(i, j);
+ }
+ }
+ }
}
// Block to collect the Stmts corresponding to all tensors.
@@ -1265,6 +1346,13 @@
bufferArgs_.end(),
symbolic_shape_args.begin(),
symbolic_shape_args.end());
+
+ // Now, add all the variables corresponding to symbolic stride inputs
+ bufferArgs_.insert(
+ bufferArgs_.end(),
+ symbolic_stride_args.begin(),
+ symbolic_stride_args.end());
+
return block;
}
@@ -1375,14 +1463,19 @@
const std::string& kernel_func_name,
std::unordered_map<c10::Symbol, NNCLoweringFunction> custom_lowerings,
std::vector<int64_t> symbolic_shape_inputs,
- bool pre_alloc /*= false*/)
+ bool pre_alloc /*= false*/,
+ std::unordered_map<
+ const torch::jit::Value*,
+ std::vector<torch::jit::StrideInput>> symbolic_strides)
: graph_(subgraph),
code_(subgraph, ""),
symbolic_shape_inputs_(std::move(symbolic_shape_inputs)),
custom_lowerings_(std::move(custom_lowerings)),
pre_alloc_(pre_alloc),
- kernel_func_name_(kernel_func_name) {
+ kernel_func_name_(kernel_func_name),
+ symbolic_strides_(std::move(symbolic_strides)) {
allow_fallback_ = fallbackAllowed();
+
if (!allow_fallback_) {
compile();
return;
@@ -1447,7 +1540,8 @@
// TODO: preallocate `runArgs` during compilation and fill in values where
// possible (e.g. for constant tensors)
std::vector<CodeGen::CallArg> runArgs;
- runArgs.reserve(inputs.size() + bufOutputs_.size());
+ runArgs.reserve(
+ inputs.size() + input_stride_args_.size() + bufOutputs_.size());
for (auto& input : inputs) {
if (input.isInt()) {
@@ -1461,6 +1555,13 @@
if (has_symbolic_shapes_) {
updateOutputSizesAndStrides(inputs);
+
+ // add stride args
+ for (const auto& input_stride_arg : input_stride_args_) {
+ runArgs.emplace_back(
+ inputs[input_stride_arg.first].toTensor().strides().at(
+ input_stride_arg.second));
+ }
}
for (size_t i = 0, e = bufOutputs_.size(); i < e; ++i) {
diff --git a/torch/csrc/jit/tensorexpr/kernel.h b/torch/csrc/jit/tensorexpr/kernel.h
index aa54cdb..5b43eee 100644
--- a/torch/csrc/jit/tensorexpr/kernel.h
+++ b/torch/csrc/jit/tensorexpr/kernel.h
@@ -2,6 +2,7 @@
#include <c10/util/variant.h>
#include <torch/csrc/jit/ir/ir.h>
+#include <torch/csrc/jit/passes/symbolic_shape_runtime_fusion.h>
#include <torch/csrc/jit/passes/utils/subgraph_utils.h>
#include <torch/csrc/jit/runtime/interpreter.h>
#include <torch/csrc/jit/tensorexpr/analysis.h>
@@ -13,6 +14,14 @@
namespace jit {
namespace tensorexpr {
+struct SmallSizeTPairHash {
+ public:
+ std::size_t operator()(const std::pair<size_t, size_t>& x) const {
+ // hashing input index and then dim index
+ return x.first * 128 + x.second;
+ }
+};
+
// Returns true if the TE fuser supports this conv2d.
bool conv2dIsSupportedJit(const Node* node);
// Returns true if the TE fuser supports this matmul.
@@ -114,20 +123,27 @@
std::unordered_map<c10::Symbol, NNCLoweringFunction> custom_lowerings =
{},
std::vector<int64_t> symbolic_shape_inputs = {},
- bool pre_alloc = false);
+ bool pre_alloc = false,
+ std::unordered_map<
+ const torch::jit::Value*,
+ std::vector<torch::jit::StrideInput>> symbolic_strides = {});
explicit TensorExprKernel(
const std::shared_ptr<Graph>& subgraph,
std::unordered_map<c10::Symbol, NNCLoweringFunction> custom_lowerings =
{},
std::vector<int64_t> symbolic_shape_inputs = {},
- bool pre_alloc = false)
+ bool pre_alloc = false,
+ std::unordered_map<
+ const torch::jit::Value*,
+ std::vector<torch::jit::StrideInput>> symbolic_strides = {})
: TensorExprKernel(
subgraph,
SubgraphUtils::generateNameForGraph(subgraph),
custom_lowerings,
symbolic_shape_inputs,
- pre_alloc) {}
+ pre_alloc,
+ symbolic_strides) {}
void run(Stack& stack);
void runFast(
@@ -243,8 +259,12 @@
ExprHandle getVarForShape(const c10::ShapeSymbol& ss);
std::vector<ExprHandle> computeInputTensorDims(
const torch::jit::Value* input);
+ ExprHandle getStrideArg(size_t tensor_input, size_t stride_index);
std::vector<ExprHandle> sizesFromSymbolicShape(
const c10::SymbolicShape& shape);
+ std::vector<ExprHandle> getInputStrides(
+ const torch::jit::Value* input,
+ const std::vector<ExprHandle>& inputTensorDims);
int64_t nInputs_ = 0;
int64_t nOutputs_ = 0;
@@ -284,6 +304,17 @@
StmtPtr stmt_ = nullptr;
bool pre_alloc_{false};
std::string kernel_func_name_;
+
+ // (input index, stride index) pairs for tensor strides that will be
+ // appended as codegen args
+ std::vector<std::pair<size_t, size_t>> input_stride_args_;
+ // map from (input index, tensor dimension) to the VarHandle of its stride arg
+ std::unordered_map<std::pair<size_t, size_t>, VarHandle, SmallSizeTPairHash>
+ strideArgToVar_;
+ std::unordered_map<
+ const torch::jit::Value*,
+ std::vector<torch::jit::StrideInput>>
+ symbolic_strides_;
};
TORCH_API int& getTECudaPointwiseLoopLevels();