[JIT][NNC] Add handling of strides to dynamic shape support. (#70464)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/70464

Add handling of strided input tensors to dynamic fusion. This is done with the same set of input striding specializations as https://github.com/pytorch/pytorch/pull/60684/:
```
  S_ONE, // STRIDE_ONE: packed
  S_CONT, // STRIDE_CONTIGUOUS: stride[i + 1] * sizes[i + 1]
  S_TRAN_CONT, // STRIDE_TRANSPOSED_CONTIGUOUS: stride[i-1] * sizes[i-1]
  S_AS_ARG, // STRIDE_AS_ARG: stride passed in as runtime value
```
and two additional specializations for (a) a contiguous tensor and (b) a channels-last tensor. Channels-last is a common case and we should optimize for it. Additionally, tensors natively store whether they are contiguous/channels-last contiguous, which makes it faster to check whether tensors follow this pattern.
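
For reference, a minimal sketch (in the spirit of the `summarizeStrideDim` helper added below) of how a single dimension is classified when a tensor is neither fully contiguous nor channels-last:

```
// Sketch only: classify one dimension of a (sizes, strides) pair.
// Mirrors summarizeStrideDim in symbolic_shape_runtime_fusion.cpp.
StrideInput classifyDim(
    c10::IntArrayRef sizes,
    c10::IntArrayRef strides,
    size_t dim,
    const std::vector<StrideInput>& already_classified) {
  if (strides[dim] == 1) {
    return StrideInput::S_ONE;
  }
  if (dim + 1 < sizes.size() &&
      strides[dim] == strides[dim + 1] * sizes[dim + 1]) {
    return StrideInput::S_CONT;
  }
  // Only mark transposed-contiguous if the previous dim was not already
  // classified as stride-contiguous, to break the mutual dependence.
  if (dim > 0 && strides[dim] == strides[dim - 1] * sizes[dim - 1] &&
      already_classified[dim - 1] != StrideInput::S_CONT) {
    return StrideInput::S_TRAN_CONT;
  }
  return StrideInput::S_AS_ARG;
}
```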

Output striding will be done in a follow up.

The striding is stored on both the TensorExprGroup node and on the guard node. The striding descriptors are stored as a vector of strings on the node for debuggability and to take advantage of being able to store IValues as attributes on nodes.
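
Roughly, the descriptors round-trip through strings as follows (a sketch based on the `insertDynamicShapesGuard` and kernel-creation changes below):

```
// Serialize per-input stride descriptors as a vector of vectors of strings
// and attach them as an IValue attribute on the fusion/guard node.
std::vector<std::vector<std::string>> serialized;
for (const auto& per_input : input_info) {
  serialized.push_back(
      c10::fmap(per_input, [](StrideInput s) { return toString(s); }));
}
guarded_node->ival_(attr::striding_inputs_desc, IValue(serialized));

// At kernel-creation time, translate the strings back to the enum.
auto strs = guarded_node->ival(attr::striding_inputs_desc)
                .to<std::vector<std::vector<std::string>>>();
std::vector<std::vector<StrideInput>> striding_inputs;
for (const auto& vec : strs) {
  striding_inputs.push_back(c10::fmap(
      vec, [](const std::string& s) { return strideInputFromString(s); }));
}
```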

As an example:

```
%8 : Double(10, 11, 12, 13, strides=[1716, 1, 143, 11], requires_grad=0, device=cpu) = prim::TensorExprGroup_0[symbolic_shape_inputs=[-37, -36, -35, -34], striding_inputs_desc=[["TENSOR_CONT_CHANNELS_LAST"]]](%x, %24, %23, %22, %21)
```
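
The kernel side then receives these descriptors as a map from input `Value*` to stride specializations; the updated C++ tests construct the kernel roughly like this (sketch mirroring the `test_dynamic_shapes.cpp` changes):

```
std::unordered_map<
    const torch::jit::Value*,
    std::vector<torch::jit::StrideInput>>
    symbolic_strides;
symbolic_strides[x_inp] = {torch::jit::StrideInput::TENSOR_CONT};

TensorExprKernel kernel(
    graph, {}, symbolic_shape_inputs, /*pre_alloc*/ false, symbolic_strides);
```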

Test Plan: Imported from OSS

Reviewed By: navahgar

Differential Revision: D33458649

Pulled By: eellison

fbshipit-source-id: c42616d3c683d70f6258180d23d3841a31a6030d
diff --git a/aten/src/ATen/core/interned_strings.h b/aten/src/ATen/core/interned_strings.h
index ee5c5d4..bcdaf99 100644
--- a/aten/src/ATen/core/interned_strings.h
+++ b/aten/src/ATen/core/interned_strings.h
@@ -475,6 +475,8 @@
   _(attr, axes)                      \
   _(attr, axis)                      \
   _(attr, symbolic_shape_inputs)     \
+  _(attr, striding_inputs_desc)      \
+  _(attr, striding_outputs_desc)     \
   _(attr, broadcast)                 \
   _(attr, direction)                 \
   _(attr, ends)                      \
diff --git a/aten/src/ATen/core/jit_type.h b/aten/src/ATen/core/jit_type.h
index 613a340..b910298 100644
--- a/aten/src/ATen/core/jit_type.h
+++ b/aten/src/ATen/core/jit_type.h
@@ -33,6 +33,26 @@
 void standardizeVectorForUnion(std::vector<TypePtr>& reference, std::vector<TypePtr>* to_fill);
 void standardizeVectorForUnion(std::vector<TypePtr>* to_flatten);
 
+inline bool is_contiguous_strides(
+    const IntArrayRef sizes,
+    const IntArrayRef strides) {
+  int n_dim = static_cast<int>(sizes.size());
+  if (n_dim == 0) {
+    return true;
+  }
+
+  if (strides[n_dim - 1] != 1) {
+    return false;
+  }
+
+  for (int i = n_dim - 2; i >= 0; i--) {
+    if (strides[i] != strides[i + 1] * sizes[i + 1]) {
+      return false;
+    }
+  }
+  return true;
+}
+
 struct AnyType;
 using AnyTypePtr = SingletonTypePtr<AnyType>;
 // Any is the top of the type hierarchy, all other types are subtypes
@@ -639,6 +659,12 @@
     return copy;
   }
 
+  TensorTypePtr withStrides(VaryingShape<Stride> sstrides) const {
+    auto cloned = clone();
+    cloned->strides_ = sstrides;
+    return cloned;
+  }
+
   TensorTypePtr withSizesStrides(
       at::IntArrayRef sizes,
       at::IntArrayRef strides) const {
diff --git a/aten/src/ATen/core/tensor_type.cpp b/aten/src/ATen/core/tensor_type.cpp
index a2262a7..a363657 100644
--- a/aten/src/ATen/core/tensor_type.cpp
+++ b/aten/src/ATen/core/tensor_type.cpp
@@ -3,27 +3,6 @@
 
 namespace c10 {
 
-namespace {
-
-inline bool is_contiguous_strides(
-    const IntArrayRef sizes,
-    const IntArrayRef strides) {
-  int n_dim = static_cast<int>(sizes.size());
-
-  if (n_dim == 0 || strides[n_dim-1] != 1) {
-    return false;
-  }
-
-  for (int i = n_dim - 2; i >= 0; i--) {
-    if (strides[i] != strides[i+1] * sizes[i+1]) {
-      return false;
-    }
-  }
-  return true;
-}
-
-} // namespace
-
 const TensorTypePtr& TensorType::get() {
   static auto value = TensorType::create(
       {}, {}, SymbolicShape(), VaryingShape<Stride>{}, {});
diff --git a/test/cpp/jit/test_misc.cpp b/test/cpp/jit/test_misc.cpp
index 27b9b91..0c837c0 100644
--- a/test/cpp/jit/test_misc.cpp
+++ b/test/cpp/jit/test_misc.cpp
@@ -2871,7 +2871,7 @@
   bool fusable_on_device = torch::jit::tensorexpr::getTEMustUseLLVMOnCPU();
   torch::jit::tensorexpr::getTEMustUseLLVMOnCPU() = false;
   setTensorExprDynamicShapeFusionEnabled(true);
-  FuseTensorExprs(graph);
+  FuseTensorExprs(graph, /*min_group_size*/2, /*add_composed_op*/true);
   Code code(graph, "");
   InterpreterState interpreter{code};
   std::vector<IValue> stack = {a, b};
@@ -2880,7 +2880,6 @@
   at::Tensor out1 = pop(stack).toTensor();
   ASSERT_TRUE(at::allclose(ref1, out1));
   ASSERT_TRUE(at::allclose(ref2, out2));
-  graph->dump();
 
   auto inp_1 = at::ones({4, 4}, TensorOptions(kCPU).dtype(at::kFloat));
   auto inp_2 = at::ones({4, 4}, TensorOptions(kCPU).dtype(at::kFloat));
diff --git a/test/cpp/jit/test_shape_analysis.cpp b/test/cpp/jit/test_shape_analysis.cpp
index 65b7d73..8056a77 100644
--- a/test/cpp/jit/test_shape_analysis.cpp
+++ b/test/cpp/jit/test_shape_analysis.cpp
@@ -69,6 +69,7 @@
   subgraph->inputs().at(1)->setType(y_type);
   subgraph->inputs().at(2)->setType(z_type);
   auto output = g->insertNode(g->create(prim::TensorExprGroup))->output();
+  subgraph->outputs().at(0)->setType(TensorType::create(at::rand({14, 5})));
   output->node()->addInput(x_inp);
   output->node()->addInput(y_inp);
   output->node()->addInput(z_inp);
@@ -268,6 +269,7 @@
   auto x_type = TensorType::create(at::rand({10, 5}));
   x_inp->setType(x_type);
   subgraph->inputs().at(0)->setType(x_type);
+  subgraph->outputs().at(0)->setType(x_type);
   auto output = g->insertNode(g->create(prim::TensorExprGroup))->output();
   output->node()->addInput(x_inp);
   output->node()->g_(attr::Subgraph, subgraph);
diff --git a/test/cpp/tensorexpr/test_dynamic_shapes.cpp b/test/cpp/tensorexpr/test_dynamic_shapes.cpp
index 04ac715..5421cd5 100644
--- a/test/cpp/tensorexpr/test_dynamic_shapes.cpp
+++ b/test/cpp/tensorexpr/test_dynamic_shapes.cpp
@@ -4,6 +4,7 @@
 #include <torch/csrc/jit/frontend/code_template.h>
 #include <torch/csrc/jit/ir/ir.h>
 #include <torch/csrc/jit/ir/irparser.h>
+#include <torch/csrc/jit/passes/symbolic_shape_runtime_fusion.h>
 #include <torch/csrc/jit/tensorexpr/kernel.h>
 #include <torch/csrc/jit/testing/file_check.h>
 #include <torch/torch.h>
@@ -48,11 +49,19 @@
   //   %4 : Float(SS(-2), SS(-3)) = aten::erf(%3)
   //   return (%4)
 
+  std::vector<torch::jit::StrideInput> input_desc = {
+      torch::jit::StrideInput::TENSOR_CONT};
+  std::unordered_map<
+      const torch::jit::Value*,
+      std::vector<torch::jit::StrideInput>>
+      symbolic_strides;
+  symbolic_strides[x_inp] = input_desc;
   std::vector<int64_t> symbolic_shape_inputs = c10::fmap(
       x_sym_dims,
       [](const c10::ShapeSymbol& shapeSym) { return shapeSym.value(); });
 
-  TensorExprKernel kernel(graph, {}, symbolic_shape_inputs);
+  TensorExprKernel kernel(
+      graph, {}, symbolic_shape_inputs, false, symbolic_strides);
   // Run with the same static dims as the one we initialized the graph with.
   {
     auto a = at::rand({10, 5}, at::TensorOptions(at::kCPU).dtype(at::kFloat));
@@ -125,7 +134,17 @@
       x_sym_dims,
       [](const c10::ShapeSymbol& shapeSym) { return shapeSym.value(); });
 
-  TensorExprKernel kernel(graph, {}, symbolic_shape_inputs);
+  std::vector<torch::jit::StrideInput> input_desc = {
+      torch::jit::StrideInput::TENSOR_CONT};
+  std::unordered_map<
+      const torch::jit::Value*,
+      std::vector<torch::jit::StrideInput>>
+      symbolic_strides;
+  symbolic_strides[x_inp] = input_desc;
+  symbolic_strides[y_inp] = input_desc;
+
+  TensorExprKernel kernel(
+      graph, {}, symbolic_shape_inputs, false, symbolic_strides);
 
   // Run with the same static dims as the one we initialized the graph with.
   {
@@ -205,7 +224,17 @@
   std::vector<int64_t> symbolic_shape_inputs(
       {x_dim0_sym.value(), x_dim1_sym.value()});
 
-  TensorExprKernel kernel(graph, {}, symbolic_shape_inputs);
+  std::vector<torch::jit::StrideInput> input_desc = {
+      torch::jit::StrideInput::TENSOR_CONT};
+  std::unordered_map<
+      const torch::jit::Value*,
+      std::vector<torch::jit::StrideInput>>
+      symbolic_strides;
+  symbolic_strides[x_inp] = input_desc;
+  symbolic_strides[y_inp] = input_desc;
+
+  TensorExprKernel kernel(
+      graph, {}, symbolic_shape_inputs, false, symbolic_strides);
 
   // Run with the same static dims as the one we initialized the graph with.
   {
@@ -276,7 +305,17 @@
 
   std::vector<int64_t> symbolic_shape_inputs({x_dim1_sym.value()});
 
-  TensorExprKernel kernel(graph, {}, symbolic_shape_inputs);
+  std::vector<torch::jit::StrideInput> input_desc = {
+      torch::jit::StrideInput::TENSOR_CONT};
+  std::unordered_map<
+      const torch::jit::Value*,
+      std::vector<torch::jit::StrideInput>>
+      symbolic_strides;
+  symbolic_strides[x_inp] = input_desc;
+  symbolic_strides[y_inp] = input_desc;
+
+  TensorExprKernel kernel(
+      graph, {}, symbolic_shape_inputs, false, symbolic_strides);
 
   // Run with the same static dims as the one we initialized the graph with.
   {
@@ -388,7 +427,18 @@
        y_dim0_sym.value(),
        cat_dim0_sym.value()});
 
-  TensorExprKernel kernel(graph, {}, symbolic_shape_inputs);
+  std::vector<torch::jit::StrideInput> input_desc = {
+      torch::jit::StrideInput::TENSOR_CONT};
+  std::unordered_map<
+      const torch::jit::Value*,
+      std::vector<torch::jit::StrideInput>>
+      symbolic_strides;
+  symbolic_strides[x_inp] = input_desc;
+  symbolic_strides[y_inp] = input_desc;
+  symbolic_strides[z_inp] = input_desc;
+
+  TensorExprKernel kernel(
+      graph, {}, symbolic_shape_inputs, false, symbolic_strides);
 
   auto a = at::rand({10, 5}, at::TensorOptions(at::kCPU).dtype(at::kFloat));
   auto b = at::rand({4, 5}, at::TensorOptions(at::kCPU).dtype(at::kFloat));
diff --git a/test/test_jit_fuser_te.py b/test/test_jit_fuser_te.py
index af6f401..ff0881e 100644
--- a/test/test_jit_fuser_te.py
+++ b/test/test_jit_fuser_te.py
@@ -58,6 +58,15 @@
         torch._C._jit_set_texpr_reductions_enabled(old)
 
 @contextlib.contextmanager
+def texpr_dynamic_enabled():
+    old = torch._C._jit_texpr_dynamic_shape_enabled()
+    torch._C._jit_set_texpr_dynamic_shape_enabled(True)
+    try:
+        yield
+    finally:
+        torch._C._jit_set_texpr_dynamic_shape_enabled(old)
+
+@contextlib.contextmanager
 def inline_fusion_groups():
     old_inlining = torch._C._debug_get_fusion_group_inlining()
     torch._C._debug_set_fusion_group_inlining(True)
@@ -2000,6 +2009,77 @@
                 test(*args)
         self.assertIn("fused_mul_add", prof.table())
 
+    def test_dynamic_shapes(self):
+        from functools import partial
+        n = 10
+
+        gen_tensor = (
+            lambda n: R(1, n),
+            lambda n: R(n, n),
+            lambda n: R(n, n).transpose(0, 1),
+            lambda n: R(n + 1, n + 1, 2)[:n, n, 0],
+            lambda n: R(n, n, 2)[:, :, 0],
+            lambda n: R(n, n + 1, n + 2, n + 3).to(memory_format=torch.channels_last),
+        )
+
+        with texpr_dynamic_enabled():
+            def foo(x, y, z):
+                return torch.sigmoid(torch.tanh(x))
+
+            foo.__disable_jit_function_caching__ = True
+
+            def fi(x, y, z):
+                return torch.tanh(x + y)
+
+            fi.__disable_jit_function_caching__ = True
+
+            def fum(x, y, z):
+                return torch.tanh(x + y) + z
+
+            fum.__disable_jit_function_caching__ = True
+
+            funcs = [foo, fi, fum]
+            with inline_fusion_groups():
+                # TODO: cuda ir eval error
+                for device in ['cpu']:
+                    I = partial(torch.randint, 0, 100, device=device)
+                    R = partial(torch.randn, device=device)
+
+                    for i, func in enumerate(funcs):
+                        num_args = i + 1
+                        for j, gen in enumerate(gen_tensor):
+                            inps = (gen(n), gen(n), gen(n))
+                            func_s = torch.jit.trace(func, inps, check_trace=False)
+                            torch._C._jit_pass_erase_shape_information(func_s.graph)
+                            for _ in range(2):
+                                x, y, z = gen(n), gen(n), gen(n)
+                                func_s(x, y, z)
+
+                            for incr in range(3):
+                                func_s(*[gen(n + 1) for _ in range(3)])
+
+                            g = torch.jit.last_executed_optimized_graph()
+                            torch._C._jit_pass_inline(g)
+                            torch._C._jit_pass_dce(g)
+
+                            # We should see only one optimized kernel
+                            FileCheck().check_count("TensorExprDynamicGuard", 1, exactly=True).run(g)
+                            self.assertEqual(func(*inps), func_s(*inps))
+
+                    gen = gen_tensor[0]
+                    inps = (gen(n), gen(n), gen(n))
+                    foo_s = torch.jit.trace(foo, inps)
+                    torch._C._jit_pass_erase_shape_information(foo_s.graph)
+                    g_prev = None
+                    for gen in gen_tensor:
+                        for i in range(3):
+                            foo_s(*[gen(n + i) for _ in range(3)])
+                            inps = (gen(n), gen(n), gen(n))
+                            self.assertEqual(foo_s(*inps), foo(*inps))
+                    g = torch.jit.last_executed_optimized_graph()
+                    torch._C._jit_pass_inline(g)
+                    torch._C._jit_pass_dce(g)
+                    FileCheck().check_count("TensorExprDynamicGuard", len(gen_tensor), exactly=True).run(g)
 
 works_list = [
     '__radd__',
diff --git a/torch/csrc/jit/passes/symbolic_shape_runtime_fusion.cpp b/torch/csrc/jit/passes/symbolic_shape_runtime_fusion.cpp
index 0ec2467..0f2c0b5 100644
--- a/torch/csrc/jit/passes/symbolic_shape_runtime_fusion.cpp
+++ b/torch/csrc/jit/passes/symbolic_shape_runtime_fusion.cpp
@@ -72,25 +72,127 @@
 void insertDynamicShapesGuard(
     const ShapeComputeGraphMapping& shape_mapping,
     Node* guarded_node,
-    bool add_composed_op);
+    bool add_composed_op,
+    std::vector<std::vector<StrideInput>>& input_info,
+    std::vector<StrideInput>& output_strides);
+
+std::string toString(StrideInput si) {
+  switch (si) {
+    case StrideInput::TENSOR_CONT:
+      return "TENSOR_CONT";
+    case StrideInput::TENSOR_CONT_CHANNELS_LAST:
+      return "TENSOR_CONT_CHANNELS_LAST";
+    case StrideInput::S_ONE:
+      return "S_ONE";
+    case StrideInput::S_CONT:
+      return "S_CONT";
+    case StrideInput::S_TRAN_CONT:
+      return "S_TRAN_CONT";
+    case StrideInput::S_AS_ARG:
+      return "S_AS_ARG";
+  }
+  TORCH_INTERNAL_ASSERT(false);
+}
+
+StrideInput strideInputFromString(const std::string& si) {
+  if (si == "TENSOR_CONT") {
+    return StrideInput::TENSOR_CONT;
+  } else if (si == "TENSOR_CONT_CHANNELS_LAST") {
+    return StrideInput::TENSOR_CONT_CHANNELS_LAST;
+  } else if (si == "S_ONE") {
+    return StrideInput::S_ONE;
+  } else if (si == "S_CONT") {
+    return StrideInput::S_CONT;
+  } else if (si == "S_TRAN_CONT") {
+    return StrideInput::S_TRAN_CONT;
+  } else if (si == "S_AS_ARG") {
+    return StrideInput::S_AS_ARG;
+  } else {
+    TORCH_INTERNAL_ASSERT(false);
+  }
+}
+
+// In the runtime guard, strides are serialized as one flat
+// vector. stride_inputs_offset indexes into that vector
+// where the strides of this tensor begin
+inline StrideInput summarizeStrideDim(
+    const c10::IntArrayRef sizes,
+    const c10::IntArrayRef strides,
+    size_t dim,
+    const std::vector<StrideInput>& stride_inputs,
+    size_t stride_inputs_offset) {
+  if (strides[dim] == 1) {
+    return StrideInput::S_ONE;
+  } else if (
+      dim + 1 < sizes.size() &&
+      strides[dim] == strides[dim + 1] * sizes[dim + 1]) {
+    return StrideInput::S_CONT;
+    // Transposed-contiguous depends on the prior dim and contiguous depends
+    // on the next dim; to avoid a mutual dependence, only treat a dim as
+    // transposed-contiguous if the prior dim was not summarized as
+    // stride-contiguous
+  } else if (
+      dim > 0 && strides[dim] == strides[dim - 1] * sizes[dim - 1] &&
+      (stride_inputs[dim - 1 + stride_inputs_offset] != StrideInput::S_CONT)) {
+    return StrideInput::S_TRAN_CONT;
+  } else {
+    return StrideInput::S_AS_ARG;
+  }
+}
+
+std::vector<StrideInput> summarizeInputStrides(const TensorType& tt) {
+  auto strides = *tt.strides().concrete_sizes();
+  auto sizes = *tt.sizes().concrete_sizes();
+  if (c10::is_contiguous_strides(sizes, strides)) {
+    return {StrideInput::TENSOR_CONT};
+    // TODO: channels last 3d
+  } else if (c10::is_channels_last_strides_2d(sizes, strides)) {
+    return {StrideInput::TENSOR_CONT_CHANNELS_LAST};
+  }
+  std::vector<StrideInput> stride_inputs;
+  for (size_t dim = 0; dim < sizes.size(); ++dim) {
+    stride_inputs.push_back(
+        summarizeStrideDim(sizes, strides, dim, stride_inputs, 0));
+  }
+  return stride_inputs;
+}
+
+// Todo: incorporate in codegen
+StrideInput summarizeOutputStrides(const TensorType& tt) {
+  auto strides = *tt.strides().concrete_sizes();
+  auto sizes = *tt.sizes().concrete_sizes();
+  // We only try to maintain output striding for channels last tensors,
+  // otherwise we defer to contiguous
+  // TODO: channels last 3d
+  if (c10::is_channels_last_strides_2d(sizes, strides)) {
+    return StrideInput::TENSOR_CONT_CHANNELS_LAST;
+  }
+  return StrideInput::TENSOR_CONT;
+}
 
 // Generalize Complete Shapes inputs to Symbolic Shapes.
 // Dimensions of value 1 will be preserved, otherwise
 // dimensions with the same value will be bucketed to the same
 // symbolic shape.
 // E.g. Tensor(5, 3), Tensor(3, 1) -> Tensor(SS(-1), SS(-2)), Tensor(SS(-2), 1)
-bool TryGeneralizeInputDimensionsToSymbolicShapes(
+// Also summarize input striding behavior. The size information is stored on
+// the type; the striding is returned. See StrideInput for a description of
+// the stride specializations.
+c10::optional<std::vector<std::vector<StrideInput>>>
+TryGeneralizeInputDimensionsToSymbolicShapes(
     std::shared_ptr<Graph> tensorexpr_graph) {
   std::map<size_t, int64_t> shape_to_sym_shape;
+  std::vector<std::vector<StrideInput>> input_striding;
+
   for (Value* v : tensorexpr_graph->inputs()) {
     if (!v->type()->cast<TensorType>()) {
       continue;
     }
-    if (!v->type()->expect<TensorType>()->sizes().isComplete()) {
-      return false;
+    auto tt = v->type()->expectRef<TensorType>();
+    if (!tt.sizes().isComplete() || !tt.strides().isComplete()) {
+      return c10::nullopt;
     }
-    auto tt = v->type()->expect<TensorType>();
-    std::vector<at::ShapeSymbol> shape_vec = *tt->symbolic_sizes().sizes();
+    input_striding.push_back(summarizeInputStrides(tt));
+    std::vector<at::ShapeSymbol> shape_vec = *tt.symbolic_sizes().sizes();
     auto new_sizes = c10::fmap(shape_vec, [&](const at::ShapeSymbol& shape) {
       auto value = shape.value();
       TORCH_INTERNAL_ASSERT(value >= 0, "Expected complete tensor");
@@ -104,9 +206,9 @@
         return new_shape_symbol;
       }
     });
-    v->setType(tt->withSymbolicShapes(c10::SymbolicShape(new_sizes)));
+    v->setType(tt.withSymbolicShapes(c10::SymbolicShape(new_sizes)));
   }
-  return true;
+  return input_striding;
 }
 
 void moveConstantTensorsOutOfSubgraph(
@@ -162,10 +264,25 @@
   moveConstantTensorsOutOfSubgraph(tensorexpr_graph_node, tensorexpr_graph);
 
   // Generalize Inputs
-  if (!TryGeneralizeInputDimensionsToSymbolicShapes(tensorexpr_graph)) {
+  auto input_striding =
+      TryGeneralizeInputDimensionsToSymbolicShapes(tensorexpr_graph);
+  if (!input_striding) {
     return false;
   }
 
+  // Get output striding behavior
+  std::vector<StrideInput> output_striding;
+  for (Value* v : tensorexpr_graph->outputs()) {
+    if (!v->type()->cast<TensorType>()) {
+      continue;
+    }
+    auto tt = v->type()->expectRef<TensorType>();
+    if (!tt.sizes().isComplete() || !tt.strides().isComplete()) {
+      return false;
+    }
+    output_striding.push_back(summarizeOutputStrides(tt));
+  }
+
   // Try To Propagate Shapes
   auto maybe_shape_compute_mapping =
       PropagateShapesAndBuildLargeShapeComputeGraph(
@@ -178,7 +295,11 @@
 
   // Insert Guard
   insertDynamicShapesGuard(
-      *maybe_shape_compute_mapping, tensorexpr_graph_node, add_composed_op);
+      *maybe_shape_compute_mapping,
+      tensorexpr_graph_node,
+      add_composed_op,
+      *input_striding,
+      output_striding);
   return true;
 }
 
@@ -219,7 +340,9 @@
 void insertDynamicShapesGuard(
     const ShapeComputeGraphMapping& shape_mapping,
     Node* guarded_node,
-    bool add_composed_op) {
+    bool add_composed_op,
+    std::vector<std::vector<StrideInput>>& input_info,
+    std::vector<StrideInput>& output_strides) {
   GRAPH_DEBUG(
       "Inserting a prim::TensorExprDynamicGuard guard for a node",
       *guarded_node);
@@ -236,7 +359,8 @@
     }
     inputs_to_check.push_back(node_input);
     guard_types.push_back(
-        subgraph->inputs().at(i)->type()->expect<TensorType>());
+        subgraph->inputs().at(i)->type()->expect<TensorType>()->withStrides(
+            c10::VaryingShape<c10::Stride>()));
   }
   TORCH_INTERNAL_ASSERT(inputs_to_check.size());
 
@@ -307,6 +431,32 @@
   }
   guarded_node->is_(attr::symbolic_shape_inputs, symbolic_shape_inputs);
 
+  std::vector<std::vector<std::string>> input_striding;
+  for (auto& vec : input_info) {
+    auto string_info =
+        fmap(vec, [&](StrideInput inp) { return toString(inp); });
+    input_striding.push_back(string_info);
+  }
+  auto ival = IValue(input_striding);
+  guarded_node->ival_(attr::striding_inputs_desc, ival);
+  typecheck_node->ival_(attr::striding_inputs_desc, ival);
+
+  for (Value* v : subgraph->inputs()) {
+    if (auto t = v->type()->cast<TensorType>()) {
+      v->setType(t->withStrides(c10::VaryingShape<c10::Stride>()));
+    }
+  }
+  for (Value* v : subgraph->outputs()) {
+    if (auto t = v->type()->cast<TensorType>()) {
+      v->setType(t->withStrides(c10::VaryingShape<c10::Stride>()));
+    }
+  }
+
+  std::vector<std::string> output_striding =
+      fmap(output_strides, [&](StrideInput inp) { return toString(inp); });
+  auto output_ival = IValue(output_striding);
+  guarded_node->ival_(attr::striding_outputs_desc, output_ival);
+
   if (add_composed_op) {
     // Create a TensorExprDynamicGroup node
     auto te_dyn_group = SubgraphUtils::createSingletonSubgraph(
@@ -401,6 +551,18 @@
           TORCH_INTERNAL_ASSERT(maybe_device);
           auto device = *maybe_device;
 
+          // flattened vector of each inputs striding behavior
+          std::vector<StrideInput> flattened_input_striding;
+          const IValue& sym_strides = node->ival(attr::striding_inputs_desc);
+          std::vector<std::vector<std::string>> sym_strides_strs =
+              sym_strides.to<std::vector<std::vector<std::string>>>();
+          for (const auto& vec : sym_strides_strs) {
+            for (const std::string& str : vec) {
+              flattened_input_striding.push_back(strideInputFromString(str));
+            }
+          }
+
           for (auto type : types) {
             auto tt = type->expect<TensorType>();
             auto ss = tt->symbolic_sizes();
@@ -431,6 +593,7 @@
               }
             }
           }
+
           const auto num_inputs = types.size();
           const auto num_symbolic_dims = sym_dim_flat_index.size();
           return [num_inputs,
@@ -438,6 +601,7 @@
                   device,
                   expected_scalar_types,
                   flattened_input_dims,
+                  flattened_input_striding,
                   num_symbolic_dims](Stack& stack) {
             at::ArrayRef<IValue> inputs = last(stack, num_inputs);
             drop(stack, num_inputs);
@@ -447,8 +611,10 @@
             // each invocation or would that mess up with multithreaded
             // inference since we are writing to it?
             // TODO - smallvector here ?
+
             std::vector<int64_t> flattened_symbolic_dims(num_symbolic_dims, -1);
             size_t flattened_dim_offset = 0;
+            size_t flattened_stride_offset = 0;
             for (const auto i : c10::irange(num_inputs)) {
               at::Tensor tensor = inputs[i].toTensor();
               if (C10_UNLIKELY(
@@ -458,18 +624,51 @@
                 push(stack, false);
                 return;
               }
-              // TODO: striding
-              if (C10_UNLIKELY(
-                      !tensor.is_contiguous(at::MemoryFormat::Contiguous))) {
-                push(stack, false);
-                return;
-              }
               const auto& sizes = tensor.sizes();
               const auto num_dims = sizes.size();
               if (C10_UNLIKELY(num_dims != expected_dims[i])) {
                 push(stack, false);
                 return;
               }
+              auto striding = flattened_input_striding[flattened_stride_offset];
+              // Tensors natively store whether they are contiguous
+              // in the default memory format or in channels last,
+              // so it is more efficient to query that property than
+              // to iterate over the dimensions and check each stride yourself
+              if (striding == StrideInput::TENSOR_CONT) {
+                if (C10_UNLIKELY(
+                        !tensor.is_contiguous(at::MemoryFormat::Contiguous))) {
+                  push(stack, false);
+                  return;
+                }
+                flattened_stride_offset += 1;
+              } else if (striding == StrideInput::TENSOR_CONT_CHANNELS_LAST) {
+                // TODO: 5D channels last
+                if (C10_UNLIKELY(!tensor.is_contiguous(
+                        at::MemoryFormat::ChannelsLast))) {
+                  push(stack, false);
+                  return;
+                }
+                flattened_stride_offset += 1;
+              } else {
+                auto strides = tensor.strides();
+                for (size_t dim = 0; dim < num_dims; ++dim) {
+                  auto summarized_dim = summarizeStrideDim(
+                      sizes,
+                      strides,
+                      dim,
+                      flattened_input_striding,
+                      flattened_stride_offset);
+                  if (C10_UNLIKELY(
+                          summarized_dim !=
+                          flattened_input_striding
+                              [dim + flattened_stride_offset])) {
+                    push(stack, false);
+                    return;
+                  }
+                }
+                flattened_stride_offset += num_dims;
+              }
               for (const auto dim_index : c10::irange(num_dims)) {
                 const int64_t dim_value =
                     flattened_input_dims[dim_index + flattened_dim_offset];
@@ -520,6 +719,7 @@
   // should be reusing Code and InterpreterState across calls to this op.
   // But that is resulting in a "No frames found" error.
   // TODO: Improve the performance of this by figuring out a better approach.
+  // NB: this is only run in Static Runtime, which is single-threaded
   return [code](Stack& stack) {
     runTensorExprDynamicGroup(code, stack);
     return 0;
diff --git a/torch/csrc/jit/passes/symbolic_shape_runtime_fusion.h b/torch/csrc/jit/passes/symbolic_shape_runtime_fusion.h
index ca03df1..0c88e11 100644
--- a/torch/csrc/jit/passes/symbolic_shape_runtime_fusion.h
+++ b/torch/csrc/jit/passes/symbolic_shape_runtime_fusion.h
@@ -3,6 +3,7 @@
 #include <torch/csrc/Export.h>
 #include <torch/csrc/jit/ir/ir.h>
 #include <torch/csrc/jit/passes/symbolic_shape_analysis.h>
+
 #include <unordered_map>
 
 namespace torch {
@@ -31,5 +32,24 @@
 
 TORCH_API void runTensorExprDynamicGroup(const Code& code, Stack& stack);
 
+enum class StrideInput {
+  // Tensors natively store whether they are contiguous or not as a property;
+  // this makes it faster to query `is_contiguous` or
+  // `is_contiguous(memory_format=channels_last)`
+  // than to loop over the sizes/strides yourself.
+  // For tensors with these properties, we only store one value:
+  TENSOR_CONT,
+  TENSOR_CONT_CHANNELS_LAST,
+  // Now, we describe the other cases, where there is one stride enum
+  // per dimension:
+  S_ONE, // STRIDE_ONE: packed
+  S_CONT, // STRIDE_CONTIGUOUS: stride[i + 1] * sizes[i + 1]
+  S_TRAN_CONT, // STRIDE_TRANSPOSED_CONTIGUOUS: stride[i-1] * sizes[i-1]
+  S_AS_ARG, // STRIDE_AS_ARG: stride passed in as runtime value
+};
+
+TORCH_API std::string toString(StrideInput si);
+TORCH_API StrideInput strideInputFromString(const std::string& si);
+
 } // namespace jit
 } // namespace torch
diff --git a/torch/csrc/jit/passes/tensorexpr_fuser.cpp b/torch/csrc/jit/passes/tensorexpr_fuser.cpp
index c6b917d..c77ca49 100644
--- a/torch/csrc/jit/passes/tensorexpr_fuser.cpp
+++ b/torch/csrc/jit/passes/tensorexpr_fuser.cpp
@@ -361,8 +361,13 @@
 
 class TensorExprFuser {
  public:
-  TensorExprFuser(std::shared_ptr<Graph> graph, size_t min_group_size)
-      : graph_(std::move(graph)), min_group_size_(min_group_size) {
+  TensorExprFuser(
+      std::shared_ptr<Graph> graph,
+      size_t min_group_size,
+      bool add_composed_op)
+      : graph_(std::move(graph)),
+        min_group_size_(min_group_size),
+        add_composed_op_(add_composed_op) {
     parseTENotFuseOption();
   }
 
@@ -1166,7 +1171,7 @@
     }
     for (Node* fusion_group : fusion_groups) {
       VLOG(1) << "GenerateGuard for fusion group: " << *fusion_group;
-      if (!GenerateGuard(fusion_group, /*add_composed_op=*/true)) {
+      if (!GenerateGuard(fusion_group, add_composed_op_)) {
         VLOG(1) << "  Unfusing the fusion group because GenerateGuard failed"
                 << std::endl;
         SubgraphUtils::unmergeSubgraph(fusion_group);
@@ -1202,9 +1207,14 @@
   std::set<NodeKind> operators_not_to_fuse;
   // Minimal size of a fusion group
   size_t min_group_size_;
+  // Whether to compose the runtime type guard and kernel in a single op
+  bool add_composed_op_;
 };
 
-void FuseTensorExprs(std::shared_ptr<Graph>& graph, size_t min_group_size) {
+void FuseTensorExprs(
+    std::shared_ptr<Graph>& graph,
+    size_t min_group_size,
+    bool add_composed_op) {
   GRAPH_DUMP("Before TExprFuser: ", graph);
 
   // Temporary change for Block code generation.
@@ -1215,7 +1225,7 @@
   // Get rid of dead code so that we don't waste effort fusing it.
   EliminateDeadCode(graph);
 
-  TensorExprFuser fuser(graph, min_group_size);
+  TensorExprFuser fuser(graph, min_group_size, add_composed_op);
   fuser.run();
 
   EliminateCommonSubexpression(graph);
@@ -1241,11 +1251,42 @@
   if (node->hasAttribute(attr::symbolic_shape_inputs)) {
     sym_shapes = node->is(attr::symbolic_shape_inputs);
   }
+
   std::unordered_map<c10::Symbol, tensorexpr::NNCLoweringFunction>
       custom_lowerings;
   auto subgraph = node->g(attr::Subgraph);
-  auto kernel = std::make_shared<tensorexpr::TensorExprKernel>(
-      subgraph, custom_lowerings, sym_shapes);
+  IValue sym_strides = node->ival(attr::striding_inputs_desc);
+
+  // The striding descriptor is serialized on the node as a vector of vectors
+  // of strings; translate it back to the StrideInput enum
+  std::vector<std::vector<std::string>> sym_strides_strs =
+      sym_strides.to<std::vector<std::vector<std::string>>>();
+  std::vector<std::vector<StrideInput>> striding_inputs;
+  for (const auto& vec : sym_strides_strs) {
+    std::vector<StrideInput> input_desc;
+    input_desc.reserve(vec.size());
+    for (const std::string& str : vec) {
+      input_desc.push_back(strideInputFromString(str));
+    }
+    striding_inputs.push_back(input_desc);
+  }
+  std::unordered_map<const Value*, std::vector<StrideInput>> stride_map;
+  size_t index = 0;
+  for (Value* v : subgraph->inputs()) {
+    if (!v->type()->cast<TensorType>()) {
+      continue;
+    }
+    stride_map[v] = striding_inputs[index];
+    index++;
+  }
+
+  std::shared_ptr<tensorexpr::TensorExprKernel> kernel =
+      std::make_shared<tensorexpr::TensorExprKernel>(
+          subgraph,
+          custom_lowerings,
+          sym_shapes,
+          /*pre_alloc*/ false,
+          stride_map);
 
   auto num_subgraph_inputs = subgraph->inputs().size();
   return [kernel, num_subgraph_inputs](Stack& stack) {
diff --git a/torch/csrc/jit/passes/tensorexpr_fuser.h b/torch/csrc/jit/passes/tensorexpr_fuser.h
index 0f917e3..01fedc4 100644
--- a/torch/csrc/jit/passes/tensorexpr_fuser.h
+++ b/torch/csrc/jit/passes/tensorexpr_fuser.h
@@ -10,9 +10,13 @@
 struct Graph;
 
 // Run TensorExpressions-based fuser.
+// If add_composed_op is true, creates a single operation that
+// performs both the runtime check that types align
+// and the dispatch to the kernel/unoptimized graph
 TORCH_API void FuseTensorExprs(
     std::shared_ptr<Graph>& graph,
-    size_t min_group_size = 2);
+    size_t min_group_size = 2,
+    bool add_composed_op = false);
 
 TORCH_API void setTensorExprFuserEnabled(bool val);
 TORCH_API bool tensorExprFuserEnabled();
diff --git a/torch/csrc/jit/python/init.cpp b/torch/csrc/jit/python/init.cpp
index 7eb4fde..2bf9b82 100644
--- a/torch/csrc/jit/python/init.cpp
+++ b/torch/csrc/jit/python/init.cpp
@@ -826,6 +826,12 @@
       .def("_jit_texpr_fallback_allowed", &tensorexpr::fallbackAllowed)
       .def("_jit_texpr_set_fallback_allowed", &tensorexpr::setFallbackAllowed)
       .def("_jit_set_texpr_reductions_enabled", &setTexprReductionsEnabled)
+      .def(
+          "_jit_set_texpr_dynamic_shape_enabled",
+          &setTensorExprDynamicShapeFusionEnabled)
+      .def(
+          "_jit_texpr_dynamic_shape_enabled",
+          &tensorExprDynamicShapeFusionEnabled)
       .def("_jit_texpr_reductions_enabled", &texprReductionsEnabled)
       .def(
           "_jit_set_te_generate_block_code",
diff --git a/torch/csrc/jit/runtime/static/fusion.cpp b/torch/csrc/jit/runtime/static/fusion.cpp
index 6080cce..6252577 100644
--- a/torch/csrc/jit/runtime/static/fusion.cpp
+++ b/torch/csrc/jit/runtime/static/fusion.cpp
@@ -330,7 +330,10 @@
   GRAPH_DEBUG("Graph before tracing: ", graph);
   auto traced_graph = TraceGraph(graph, sample_inputs);
   GRAPH_DEBUG("Graph after tracing: ", traced_graph);
-  FuseTensorExprs(traced_graph);
+  FuseTensorExprs(
+      traced_graph,
+      /*min_group_size*/ 2,
+      /*add_composed_op*/ true);
   graph->block()->clear();
   graph->block()->cloneFrom(traced_graph->block(), nullptr);
   GRAPH_DUMP("Graph after fusion: ", graph);
diff --git a/torch/csrc/jit/tensorexpr/codegen.h b/torch/csrc/jit/tensorexpr/codegen.h
index 8b20412..3115527 100644
--- a/torch/csrc/jit/tensorexpr/codegen.h
+++ b/torch/csrc/jit/tensorexpr/codegen.h
@@ -118,6 +118,7 @@
   BufferArg(Tensor tensor) : buf_(tensor.buf()) {}
   BufferArg(const VarHandle& var) : var_(var.node()), isVar_(true) {}
   BufferArg(const BufHandle& buf) : buf_(buf.node()) {}
+  BufferArg(const BufPtr& buf) : buf_(buf) {}
 
   VarPtr var() const {
     return isVar_ ? var_ : buf_->base_handle();
diff --git a/torch/csrc/jit/tensorexpr/kernel.cpp b/torch/csrc/jit/tensorexpr/kernel.cpp
index 81ca890..04f8d89 100644
--- a/torch/csrc/jit/tensorexpr/kernel.cpp
+++ b/torch/csrc/jit/tensorexpr/kernel.cpp
@@ -8,7 +8,9 @@
 #include <c10/util/irange.h>
 #include <c10/util/string_utils.h>
 #include <torch/csrc/jit/jit_log.h>
+#include <torch/csrc/jit/passes/symbolic_shape_runtime_fusion.h>
 #include <torch/csrc/jit/tensorexpr/analysis.h>
+#include <torch/csrc/jit/tensorexpr/expr.h>
 #include <torch/csrc/jit/tensorexpr/graph_opt.h>
 #include <torch/csrc/jit/tensorexpr/ir_printer.h>
 #include <torch/csrc/jit/tensorexpr/ir_simplifier.h>
@@ -903,102 +905,159 @@
   return dims;
 }
 
-std::vector<ExprHandle> TensorExprKernel::computeInputTensorDims(
-    const torch::jit::Value* input) {
-  auto tt = input->type()->expect<TensorType>();
+ExprHandle TensorExprKernel::getStrideArg(
+    size_t tensor_input_index,
+    size_t stride_index) {
+  auto it = strideArgToVar_.find(
+      std::pair<size_t, size_t>(tensor_input_index, stride_index));
+  if (it == strideArgToVar_.end()) {
+    VarHandle var(
+        "stride_arg" + std::to_string(tensor_input_index) + "_" +
+            std::to_string(stride_index),
+        kLong);
+    strideArgToVar_[std::pair<size_t, size_t>(
+        tensor_input_index, stride_index)] = var;
+    return std::move(var);
+  }
+  return it->second;
+}
 
-  // First case - static shapes
+std::vector<ExprHandle> TensorExprKernel::getInputStrides(
+    const torch::jit::Value* input,
+    const std::vector<ExprHandle>& inputTensorDims) {
+  std::vector<ExprHandle> inputTensorStrides;
   if (input->isCompleteTensor()) {
-    if (isContiguous(input)) {
-      return toExprHandles(*tt->sizes().concrete_sizes());
+    auto const strides =
+        input->type()->expect<TensorType>()->strides().concrete_sizes();
+    for (size_t stride : *strides) {
+      inputTensorStrides.push_back(LongImm::make(stride));
     }
+    return inputTensorStrides;
+  }
 
-    // Non-contiguous tensors are represented as 1-d buffers in NNC
-    ExprHandle flat_size = 1;
-    for (size_t i = 0; i < *tt->sizes().size(); i++) {
-      auto size = *tt->sizes()[i];
-      if (size == 0) {
-        flat_size = 0;
-        break;
+  size_t rank = inputTensorDims.size();
+  TORCH_INTERNAL_ASSERT(symbolic_strides_.count(input));
+  std::vector<StrideInput>& stride_input = symbolic_strides_[input];
+  if (stride_input.size() == 1 &&
+      (stride_input[0] == StrideInput::TENSOR_CONT_CHANNELS_LAST ||
+       stride_input[0] == StrideInput::TENSOR_CONT)) {
+    auto strides = stride_input[0] == StrideInput::TENSOR_CONT
+        ? make_contiguous_strides(inputTensorDims)
+        : make_channels_last_strides(inputTensorDims);
+    return fmap(strides, [&](ExprPtr stride) { return ExprHandle(stride); });
+  }
+
+  inputTensorStrides.resize(rank);
+  std::vector<bool> stride_set;
+  for (size_t i = 0; i < rank; ++i) {
+    stride_set.push_back(false);
+  }
+  // first, generate non-dependent values
+  size_t generated_strides = 0;
+  for (const auto i : c10::irange(rank)) {
+    if (stride_input[i] == torch::jit::StrideInput::S_ONE) {
+      inputTensorStrides[i] = LongImm::make(1);
+      stride_set[i] = true;
+      generated_strides++;
+    } else if (stride_input[i] == torch::jit::StrideInput::S_AS_ARG) {
+      size_t input_index = input->offset();
+      inputTensorStrides[i] = getStrideArg(input_index, i);
+      stride_set[i] = true;
+      generated_strides++;
+    }
+  }
+  // Contiguous and Transposed Contiguous depend on adjacent values
+  while (generated_strides != rank) {
+    for (int i = static_cast<int>(rank) - 1; i >= 0; i--) {
+      if (stride_input[i] == torch::jit::StrideInput::S_CONT &&
+          stride_set[i + 1]) {
+        inputTensorStrides[i] =
+            inputTensorStrides[i + 1] * inputTensorDims[i + 1];
+
+        stride_set[i] = true;
+        generated_strides++;
       }
-      flat_size = flat_size + (size - 1) * *tt->strides()[i];
     }
-    flat_size = IRSimplifier::simplify(flat_size);
-    return {flat_size};
+    for (int i = 0; i < static_cast<int>(rank); i++) {
+      if (stride_input[i] == torch::jit::StrideInput::S_TRAN_CONT &&
+          stride_set[i - 1]) {
+        inputTensorStrides[i] =
+            inputTensorStrides[i - 1] * inputTensorDims[i - 1];
+        stride_set[i] = true;
+        generated_strides++;
+      }
+    }
   }
-
-  // Second case - symbolic shapes
-  // We only handle symbolic shape input tensors that are contiguous.
-  // TODO: Handle strided tensors with symbolic shapes.
-  auto const& symbolicShape = tt->symbolic_sizes();
-  auto rank = symbolicShape.rank();
-  if (!rank) {
-    throw std::runtime_error("Symbolic shapes must have static ranks.");
-  }
-  std::vector<ExprHandle> inputTensorDims;
-  for (const auto i : c10::irange(*rank)) {
-    inputTensorDims.emplace_back(getVarForShape(symbolicShape[i]));
-  }
-  return inputTensorDims;
+  return inputTensorStrides;
 }
 
 Tensor TensorExprKernel::bindInput(const torch::jit::Value* input) {
   auto const& t = input->type();
   auto const& outputs = input->owningGraph()->outputs();
   std::unordered_set<const Value*> outputs_set(outputs.begin(), outputs.end());
+
   Tensor result(nullptr, nullptr);
   switch (t->kind()) {
     case TypeKind::TensorType: {
       auto tt = input->type()->cast<TensorType>();
-      auto inputTensorDims = computeInputTensorDims(input);
-
-      BufHandle inBuffer(
-          "t" + input_name_map_[input],
-          inputTensorDims,
-          ToDtype(static_cast<ScalarType>(*tt->scalarType())));
-      bufferArgs_.emplace_back(inBuffer);
+      bool contiguous_concrete_tensor =
+          (input->isCompleteTensor() && isContiguous(input));
+      bool contiguous_strided_tensor = symbolic_strides_.count(input) &&
+          symbolic_strides_[input].size() == 1 &&
+          symbolic_strides_[input][0] == torch::jit::StrideInput::TENSOR_CONT;
 
       // We don't need to copy the input if:
       //  1) it is not an output AND
       //  2) it is contiguous
-      //
-      // For static shapes we can check (2) directly, for symbolic shapes we
-      // currently *assume* it is contiguous.
-      //
-      // TODO: update this logic as soon as we start supporting symbolic
-      // strides.
-      if (!outputs_set.count(input) &&
-          (!input->isCompleteTensor() || isContiguous(input))) {
+      bool contiguous = contiguous_concrete_tensor || contiguous_strided_tensor;
+      if (!outputs_set.count(input) && contiguous) {
+        BufHandle inBuffer(
+            "t" + input_name_map_[input],
+            sizesFromSymbolicShape(tt->symbolic_sizes()),
+            ToDtype(static_cast<ScalarType>(*tt->scalarType())));
         bufs_.emplace(input, inBuffer.node());
+        bufferArgs_.emplace_back(inBuffer);
         break;
       }
 
-      // Symbolic shapes case:
-      if (!input->isCompleteTensor()) {
-        TORCH_INTERNAL_ASSERT(outputs_set.count(input));
-        result = Compute(
-            "input" + c10::to_string(bufs_.size() + 1),
-            c10::fmap<DimArg>(inputTensorDims),
-            [&](const std::vector<VarHandle>& axes) {
-              return inBuffer.load(axes);
-            });
-        bufs_.emplace(input, result.buf());
-        break;
+      // If the input isn't contiguous or is an output,
+      // copy the strided input into a contiguous buffer that is
+      // then used in all further compute
+      std::vector<DimArg> inputTensorDims;
+      auto size_handles = sizesFromSymbolicShape(tt->symbolic_sizes());
+      for (size_t i = 0; i < size_handles.size(); i++) {
+        auto size = size_handles[i];
+        inputTensorDims.emplace_back(DimArg(size, "i" + c10::to_string(i)));
       }
+      auto inputTensorStrides = getInputStrides(input, size_handles);
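+      // Flat extent of the strided buffer: 1 + sum_i (size_i - 1) * stride_i,
+      // or 0 if any dimension has size 0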
+      ExprHandle flat_size = 1;
+      for (size_t i = 0; i < size_handles.size(); ++i) {
+        auto size = size_handles[i];
+        if (size.AsNode<LongImm>() && immediateAs<int64_t>(size.node()) == 0) {
+          flat_size = 0;
+          break;
+        }
+        flat_size = flat_size + (size - 1) * inputTensorStrides[i];
+      }
+      flat_size = IRSimplifier::simplify(flat_size);
+      BufHandle inBuffer(
+          "t" + input_name_map_[input],
+          {flat_size},
+          ToDtype(static_cast<ScalarType>(*tt->scalarType())));
 
-      // Static shapes, non-contiguous case:
-      auto const strides = tt->strides();
       result = Compute(
           "input" + c10::to_string(bufs_.size() + 1),
-          c10::fmap<DimArg>(*tt->sizes().concrete_sizes()),
+          inputTensorDims,
           [&](const std::vector<VarHandle>& axes) {
             ExprHandle idx = 0;
             for (size_t i = 0; i < axes.size(); i++) {
-              idx = idx + axes[i] * *strides[i];
+              idx = idx + axes[i] * inputTensorStrides[i];
             }
             return inBuffer.load(idx);
           });
       bufs_.emplace(input, result.buf());
+      bufferArgs_.emplace_back(inBuffer);
       break;
     }
     case TypeKind::FloatType: {
@@ -1213,6 +1272,8 @@
 
 BlockPtr TensorExprKernel::bindAllInputs() {
   std::vector<CodeGen::BufferArg> symbolic_shape_args;
+  std::vector<CodeGen::BufferArg> symbolic_stride_args;
+
   auto symbolic_shape_inputs_start_pos =
       nInputs_ - symbolic_shape_inputs_.size();
   if (has_symbolic_shapes_) {
@@ -1231,6 +1292,7 @@
     // their symbolic sizes needs to be associated with these variables we
     // create for the symbolic input params.
     symbolic_shape_args.reserve(symbolic_shape_inputs_.size());
+
     for (size_t i = symbolic_shape_inputs_start_pos; i < nInputs_; ++i) {
       auto input = graph_->inputs()[i];
       if (input->type()->kind() != TypeKind::IntType) {
@@ -1247,6 +1309,25 @@
       shapeSymbolToVar_[symbolic_shape_inputs_[i]] =
           scalars_[graph_->inputs()[symbolic_shape_inputs_start_pos + i]];
     }
+
+    // Next, process the tensor inputs and create a stride argument for each
+    // stride that is passed in as a runtime value (S_AS_ARG)
+    for (size_t i = 0; i < symbolic_shape_inputs_start_pos; ++i) {
+      auto input = graph_->inputs()[i];
+      auto tt = input->type()->cast<TensorType>();
+      if (!tt) {
+        continue;
+      }
+      TORCH_INTERNAL_ASSERT(symbolic_strides_.count(input));
+      auto symbolic_stride = symbolic_strides_[input];
+      for (size_t j = 0; j < symbolic_stride.size(); ++j) {
+        if (symbolic_stride[j] == torch::jit::StrideInput::S_AS_ARG) {
+          VarHandle v("v" + input_name_map_[input], kLong);
+          symbolic_stride_args.emplace_back(v);
+          strideArgToVar_[{i, j}] = v;
+          input_stride_args_.emplace_back(i, j);
+        }
+      }
+    }
   }
 
   // Block to collect the Stmts corresponding to all tensors.
@@ -1265,6 +1346,13 @@
       bufferArgs_.end(),
       symbolic_shape_args.begin(),
       symbolic_shape_args.end());
+
+  // Now, add all the variables corresponding to symbolic stride inputs
+  bufferArgs_.insert(
+      bufferArgs_.end(),
+      symbolic_stride_args.begin(),
+      symbolic_stride_args.end());
+
   return block;
 }
 
@@ -1375,14 +1463,19 @@
     const std::string& kernel_func_name,
     std::unordered_map<c10::Symbol, NNCLoweringFunction> custom_lowerings,
     std::vector<int64_t> symbolic_shape_inputs,
-    bool pre_alloc /*= false*/)
+    bool pre_alloc /*= false*/,
+    std::unordered_map<
+        const torch::jit::Value*,
+        std::vector<torch::jit::StrideInput>> symbolic_strides)
     : graph_(subgraph),
       code_(subgraph, ""),
       symbolic_shape_inputs_(std::move(symbolic_shape_inputs)),
       custom_lowerings_(std::move(custom_lowerings)),
       pre_alloc_(pre_alloc),
-      kernel_func_name_(kernel_func_name) {
+      kernel_func_name_(kernel_func_name),
+      symbolic_strides_(std::move(symbolic_strides)) {
   allow_fallback_ = fallbackAllowed();
+
   if (!allow_fallback_) {
     compile();
     return;
@@ -1447,7 +1540,8 @@
   // TODO: preallocate `runArgs` during compilation and fill in values where
   // possible (e.g. for constant tensors)
   std::vector<CodeGen::CallArg> runArgs;
-  runArgs.reserve(inputs.size() + bufOutputs_.size());
+  runArgs.reserve(
+      inputs.size() + input_stride_args_.size() + bufOutputs_.size());
 
   for (auto& input : inputs) {
     if (input.isInt()) {
@@ -1461,6 +1555,13 @@
 
   if (has_symbolic_shapes_) {
     updateOutputSizesAndStrides(inputs);
+
+    // add stride args
+    for (const auto& input_stride_arg : input_stride_args_) {
+      runArgs.emplace_back(
+          inputs[input_stride_arg.first].toTensor().strides().at(
+              input_stride_arg.second));
+    }
   }
 
   for (size_t i = 0, e = bufOutputs_.size(); i < e; ++i) {
diff --git a/torch/csrc/jit/tensorexpr/kernel.h b/torch/csrc/jit/tensorexpr/kernel.h
index aa54cdb..5b43eee 100644
--- a/torch/csrc/jit/tensorexpr/kernel.h
+++ b/torch/csrc/jit/tensorexpr/kernel.h
@@ -2,6 +2,7 @@
 
 #include <c10/util/variant.h>
 #include <torch/csrc/jit/ir/ir.h>
+#include <torch/csrc/jit/passes/symbolic_shape_runtime_fusion.h>
 #include <torch/csrc/jit/passes/utils/subgraph_utils.h>
 #include <torch/csrc/jit/runtime/interpreter.h>
 #include <torch/csrc/jit/tensorexpr/analysis.h>
@@ -13,6 +14,14 @@
 namespace jit {
 namespace tensorexpr {
 
+struct SmallSizeTPairHash {
+ public:
+  std::size_t operator()(const std::pair<size_t, size_t>& x) const {
+    // hashing input index and then dim index
+    return x.first * 128 + x.second;
+  }
+};
+
 // Returns true if the TE fuser supports this conv2d.
 bool conv2dIsSupportedJit(const Node* node);
 // Returns true if the TE fuser supports this matmul.
@@ -114,20 +123,27 @@
       std::unordered_map<c10::Symbol, NNCLoweringFunction> custom_lowerings =
           {},
       std::vector<int64_t> symbolic_shape_inputs = {},
-      bool pre_alloc = false);
+      bool pre_alloc = false,
+      std::unordered_map<
+          const torch::jit::Value*,
+          std::vector<torch::jit::StrideInput>> symbolic_strides = {});
 
   explicit TensorExprKernel(
       const std::shared_ptr<Graph>& subgraph,
       std::unordered_map<c10::Symbol, NNCLoweringFunction> custom_lowerings =
           {},
       std::vector<int64_t> symbolic_shape_inputs = {},
-      bool pre_alloc = false)
+      bool pre_alloc = false,
+      std::unordered_map<
+          const torch::jit::Value*,
+          std::vector<torch::jit::StrideInput>> symbolic_strides = {})
       : TensorExprKernel(
             subgraph,
             SubgraphUtils::generateNameForGraph(subgraph),
             custom_lowerings,
             symbolic_shape_inputs,
-            pre_alloc) {}
+            pre_alloc,
+            symbolic_strides) {}
 
   void run(Stack& stack);
   void runFast(
@@ -243,8 +259,12 @@
   ExprHandle getVarForShape(const c10::ShapeSymbol& ss);
   std::vector<ExprHandle> computeInputTensorDims(
       const torch::jit::Value* input);
+  ExprHandle getStrideArg(size_t tensor_input, size_t stride_index);
   std::vector<ExprHandle> sizesFromSymbolicShape(
       const c10::SymbolicShape& shape);
+  std::vector<ExprHandle> getInputStrides(
+      const torch::jit::Value* input,
+      const std::vector<ExprHandle>& inputTensorDims);
 
   int64_t nInputs_ = 0;
   int64_t nOutputs_ = 0;
@@ -284,6 +304,17 @@
   StmtPtr stmt_ = nullptr;
   bool pre_alloc_{false};
   std::string kernel_func_name_;
+
+  // (input index, stride index) of each tensor stride that will be appended
+  // as a codegen arg
+  std::vector<std::pair<size_t, size_t>> input_stride_args_;
+  // map from <input index, tensor dimension> to the VarHandle used for that
+  // stride arg
+  std::unordered_map<std::pair<size_t, size_t>, VarHandle, SmallSizeTPairHash>
+      strideArgToVar_;
+  std::unordered_map<
+      const torch::jit::Value*,
+      std::vector<torch::jit::StrideInput>>
+      symbolic_strides_;
 };
 
 TORCH_API int& getTECudaPointwiseLoopLevels();