[Inductor] Support vectorized transpose in CPP backend (#91532)

Fix https://github.com/pytorch/torchdynamo/issues/1915
This PR adds vectorization support for transposed operations in the TorchInductor CPP backend. It contains the following changes:
1. `CppTile2DKernelChecker` is added to check whether the optimization is applicable. We only address a narrow set of situations for now; all of the following conditions must be met: 1) there exists one and only one fp32 load/store whose buffer access is contiguous in an outer loop var; 2) when a load/store does not have contiguous access in an outer loop var, the access must be vectorizable from the inner-most dim; 3) no reduction. More scenarios/operations will be supported in future PRs.
2. If `CppTile2DKernelChecker` reports that the optimization is applicable, `CppKernelProxy` splits/tiles the loops at both the outer loop var with contiguous buffer access and the inner-most loop var.
3. The main loop split from the outer loop var is further split at the inner-most level and then handled by `CppTile2DKernel` and `CppTile2DTailKernel`, which generate the transposed loads/stores. The former does vectorized transposed loads/stores on tiles and then vectorized load/store/compute along the inner-most loop axis (see the sketch after this list); its vectorized transpose micro-kernel is adapted from FBGEMM. The latter simply does scalar operations.
4. The tail loop split from the outer loop var is handled directly by `CppKernel` with scalar operations.
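
Conceptually, the generated 2D-tiled code follows the pattern in the hand-written C++ sketch below. This is an illustration only, not the actual Inductor output: the function name, the buffer names (`in_ptr0`, `out_ptr0`), and the assumption that only full tiles are handled are made up for the example; the real tail handling is described in items 3 and 4 above.

```cpp
// Illustrative sketch of a 2D-tiled transpose copy (out = in^T) using the new
// at::vec::transpose_mxn helper; hypothetical names, full tiles only.
#include <ATen/cpu/vec/vec.h>

void transpose_copy_tiled(const float* in_ptr0, float* out_ptr0, long M, long N) {
  // Tiling factor = SIMD width for float (8 for AVX2, 16 for AVX-512).
  constexpr int T = at::vec::Vectorized<float>::size();
  for (long i0 = 0; i0 < M / T; ++i0) {    // main loop over the outer tiled var
    for (long i1 = 0; i1 < N / T; ++i1) {  // main loop over the inner-most tiled var
      // "preload": transposed tile load into a small temporary buffer
      alignas(64) float tmp0[T * T];
      at::vec::transpose_mxn<float, T, T>(
          in_ptr0 + i0 * T * N + i1 * T, /*ld_src=*/N, tmp0, /*ld_dst=*/T);
      // kernel inner loop: rows of tmp0 are now contiguous, so the usual
      // vectorized load/compute/store path applies
      for (int i0_inner = 0; i0_inner < T; ++i0_inner) {
        auto v = at::vec::Vectorized<float>::loadu(tmp0 + i0_inner * T);
        v.store(out_ptr0 + (i1 * T + i0_inner) * M + i0 * T);
      }
    }
  }
  // The tail iterations of both loops (when M or N is not a multiple of T)
  // fall back to scalar kernels, as described above.
}
```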

Next steps:
1. Support vectorized transpose with a smaller tile size in one dim and a bigger tile size in the other, e.g., 3x784.
2. Support vectorized reduction along the outer loop var (i.e., accesses contiguous in the outer loop var but not in the inner-most loop var).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/91532
Approved by: https://github.com/EikanWang, https://github.com/jansel
diff --git a/aten/src/ATen/cpu/vec/vec256/vec256_float.h b/aten/src/ATen/cpu/vec/vec256/vec256_float.h
index b12e8b0..4f1fc74 100644
--- a/aten/src/ATen/cpu/vec/vec256/vec256_float.h
+++ b/aten/src/ATen/cpu/vec/vec256/vec256_float.h
@@ -453,6 +453,105 @@
   return _mm256_fmsub_ps(a, b, c);
 }
 
+// Used by Inductor CPP codegen
+template<>
+inline void transpose_mxn<float, 8, 8>(
+    const float* src,
+    int64_t ld_src,
+    float* dst,
+    int64_t ld_dst) {
+  // load from src to registers
+  // a: a0  a1  a2  a3  a4  a5  a6  a7
+  // b: b0  b1  b2  b3  b4  b5  b6  b7
+  // c: c0  c1  c2  c3  c4  c5  c6  c7
+  // d: d0  d1  d2  d3  d4  d5  d6  d7
+  // e: e0  e1  e2  e3  e4  e5  e6  e7
+  // f: f0  f1  f2  f3  f4  f5  f6  f7
+  // g: g0  g1  g2  g3  g4  g5  g6  g7
+  // h: h0  h1  h2  h3  h4  h5  h6  h7
+  __m256 a = _mm256_loadu_ps(&src[0 * ld_src]);
+  __m256 b = _mm256_loadu_ps(&src[1 * ld_src]);
+  __m256 c = _mm256_loadu_ps(&src[2 * ld_src]);
+  __m256 d = _mm256_loadu_ps(&src[3 * ld_src]);
+  __m256 e = _mm256_loadu_ps(&src[4 * ld_src]);
+  __m256 f = _mm256_loadu_ps(&src[5 * ld_src]);
+  __m256 g = _mm256_loadu_ps(&src[6 * ld_src]);
+  __m256 h = _mm256_loadu_ps(&src[7 * ld_src]);
+
+  __m256 ta, tb, tc, td, te, tf, tg, th;
+  // unpacking and interleaving 32-bit elements
+  // a0  b0  a1  b1  a4  b4  a5  b5
+  // a2  b2  a3  b3  a6  b6  a7  b7
+  // c0  d0  c1  d1 ...
+  // c2  d2  c3  d3 ...
+  // e0  f0  e1  f1 ...
+  // e2  f2  e3  f3 ...
+  // g0  h0  g1  h1 ...
+  // g2  h2  g3  h3 ...
+  ta = _mm256_unpacklo_ps(a, b);
+  tb = _mm256_unpackhi_ps(a, b);
+  tc = _mm256_unpacklo_ps(c, d);
+  td = _mm256_unpackhi_ps(c, d);
+  te = _mm256_unpacklo_ps(e, f);
+  tf = _mm256_unpackhi_ps(e, f);
+  tg = _mm256_unpacklo_ps(g, h);
+  th = _mm256_unpackhi_ps(g, h);
+
+  // unpacking and interleaving 64-bit elements
+  //  a0  b0  c0  d0  a4  b4  c4  d4
+  //  a1  b1  c1  d1 ...
+  //  a2  b2  c2  d2 ...
+  //  a3  b3  c3  d3 ...
+  //  e0  f0  g0  h0  e4  f4  g4  h4
+  //  e1  f1  g1  h1 ...
+  //  e2  f2  g2  h2 ...
+  //  e3  f3  g3  h3 ...
+  a = _mm256_castpd_ps(
+      _mm256_unpacklo_pd(_mm256_castps_pd(ta), _mm256_castps_pd(tc)));
+  b = _mm256_castpd_ps(
+      _mm256_unpackhi_pd(_mm256_castps_pd(ta), _mm256_castps_pd(tc)));
+  c = _mm256_castpd_ps(
+      _mm256_unpacklo_pd(_mm256_castps_pd(tb), _mm256_castps_pd(td)));
+  d = _mm256_castpd_ps(
+      _mm256_unpackhi_pd(_mm256_castps_pd(tb), _mm256_castps_pd(td)));
+  e = _mm256_castpd_ps(
+      _mm256_unpacklo_pd(_mm256_castps_pd(te), _mm256_castps_pd(tg)));
+  f = _mm256_castpd_ps(
+      _mm256_unpackhi_pd(_mm256_castps_pd(te), _mm256_castps_pd(tg)));
+  g = _mm256_castpd_ps(
+      _mm256_unpacklo_pd(_mm256_castps_pd(tf), _mm256_castps_pd(th)));
+  h = _mm256_castpd_ps(
+      _mm256_unpackhi_pd(_mm256_castps_pd(tf), _mm256_castps_pd(th)));
+
+  //  shuffle 128-bits (composed of 4 32-bit elements)
+  //  a0  b0  c0  d0  e0  f0  g0  h0
+  //  a1  b1  c1  d1 ...
+  //  a2  b2  c2  d2 ...
+  //  a3  b3  c3  d3 ...
+  //  a4  b4  c4  d4 ...
+  //  a5  b5  c5  d5 ...
+  //  a6  b6  c6  d6 ...
+  //  a7  b7  c7  d7 ...
+  ta = _mm256_permute2f128_ps(a, e, 0x20);
+  tb = _mm256_permute2f128_ps(b, f, 0x20);
+  tc = _mm256_permute2f128_ps(c, g, 0x20);
+  td = _mm256_permute2f128_ps(d, h, 0x20);
+  te = _mm256_permute2f128_ps(a, e, 0x31);
+  tf = _mm256_permute2f128_ps(b, f, 0x31);
+  tg = _mm256_permute2f128_ps(c, g, 0x31);
+  th = _mm256_permute2f128_ps(d, h, 0x31);
+
+  // store from registers to dst
+  _mm256_storeu_ps(&dst[0 * ld_dst], ta);
+  _mm256_storeu_ps(&dst[1 * ld_dst], tb);
+  _mm256_storeu_ps(&dst[2 * ld_dst], tc);
+  _mm256_storeu_ps(&dst[3 * ld_dst], td);
+  _mm256_storeu_ps(&dst[4 * ld_dst], te);
+  _mm256_storeu_ps(&dst[5 * ld_dst], tf);
+  _mm256_storeu_ps(&dst[6 * ld_dst], tg);
+  _mm256_storeu_ps(&dst[7 * ld_dst], th);
+}
+
 #endif
 
 }}}
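The 8x8 AVX2 specialization above can be exercised like any other `transpose_mxn` instantiation. A minimal, hypothetical sanity check (not part of this patch) compares it against the plain scalar definition of a transpose:

```cpp
// Hypothetical sanity check: the 8x8 specialization (when compiled in) must
// agree with the scalar definition dst[j*ld_dst + i] == src[i*ld_src + j].
#include <ATen/cpu/vec/vec.h>
#include <cassert>

int main() {
  float src[8 * 8], dst[8 * 8];
  for (int i = 0; i < 64; ++i) src[i] = static_cast<float>(i);

  // Dense 8x8 buffers, so both leading dimensions are 8.
  at::vec::transpose_mxn<float, 8, 8>(src, /*ld_src=*/8, dst, /*ld_dst=*/8);

  for (int i = 0; i < 8; ++i)
    for (int j = 0; j < 8; ++j)
      assert(dst[j * 8 + i] == src[i * 8 + j]);
  return 0;
}
```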
diff --git a/aten/src/ATen/cpu/vec/vec512/vec512_float.h b/aten/src/ATen/cpu/vec/vec512/vec512_float.h
index 7cbab8d..b4d8f49 100644
--- a/aten/src/ATen/cpu/vec/vec512/vec512_float.h
+++ b/aten/src/ATen/cpu/vec/vec512/vec512_float.h
@@ -499,6 +499,223 @@
   return _mm512_fmsub_ps(a, b, c);
 }
 
+// TODO(jgong5): rewrite with ATEN vectorized (need to add unpack and shuffle)
+// Used by Inductor CPP codegen
+// Code adapted from FBGEMM:
+// https://github.com/pytorch/FBGEMM/blob/39a423e4ad1a04b77fea81c7d09c3e6f8984fae9/src/UtilsAvx512.cc#LL19C6-L19C6
+// 16 * 6 = 96 instructions
+template<>
+inline void transpose_mxn<float, 16, 16>(
+    const float* src,
+    int64_t ld_src,
+    float* dst,
+    int64_t ld_dst) {
+  // load from src to registers
+  // a: a0  a1  a2  a3  a4  a5  a6  a7  a8  a9  a10 a11 a12 a13 a14 a15
+  // b: b0  b1  b2  b3  b4  b5  b6  b7  b8  b9  b10 b11 b12 b13 b14 b15
+  // c: c0  c1  c2  c3  c4  c5  c6  c7  c8  c9  c10 c11 c12 c13 c14 c15
+  // d: d0  d1  d2  d3  d4  d5  d6  d7  d8  d9  d10 d11 d12 d13 d14 d15
+  // e: e0  e1  e2  e3  e4  e5  e6  e7  e8  e9  e10 e11 e12 e13 e14 e15
+  // f: f0  f1  f2  f3  f4  f5  f6  f7  f8  f9  f10 f11 f12 f13 f14 f15
+  // g: g0  g1  g2  g3  g4  g5  g6  g7  g8  g9  g10 g11 g12 g13 g14 g15
+  // h: h0  h1  h2  h3  h4  h5  h6  h7  h8  h9  h10 h11 h12 h13 h14 h15
+  // i: i0  i1  i2  i3  i4  i5  i6  i7  i8  i9  i10 i11 i12 i13 i14 i15
+  // j: j0  j1  j2  j3  j4  j5  j6  j7  j8  j9  j10 j11 j12 j13 j14 j15
+  // k: k0  k1  k2  k3  k4  k5  k6  k7  k8  k9  k10 k11 k12 k13 k14 k15
+  // l: l0  l1  l2  l3  l4  l5  l6  l7  l8  l9  l10 l11 l12 l13 l14 l15
+  // m: m0  m1  m2  m3  m4  m5  m6  m7  m8  m9  m10 m11 m12 m13 m14 m15
+  // n: n0  n1  n2  n3  n4  n5  n6  n7  n8  n9  n10 n11 n12 n13 n14 n15
+  // o: o0  o1  o2  o3  o4  o5  o6  o7  o8  o9  o10 o11 o12 o13 o14 o15
+  // p: p0  p1  p2  p3  p4  p5  p6  p7  p8  p9  p10 p11 p12 p13 p14 p15
+  __m512 a = _mm512_loadu_ps(&src[0 * ld_src]);
+  __m512 b = _mm512_loadu_ps(&src[1 * ld_src]);
+  __m512 c = _mm512_loadu_ps(&src[2 * ld_src]);
+  __m512 d = _mm512_loadu_ps(&src[3 * ld_src]);
+  __m512 e = _mm512_loadu_ps(&src[4 * ld_src]);
+  __m512 f = _mm512_loadu_ps(&src[5 * ld_src]);
+  __m512 g = _mm512_loadu_ps(&src[6 * ld_src]);
+  __m512 h = _mm512_loadu_ps(&src[7 * ld_src]);
+  __m512 i = _mm512_loadu_ps(&src[8 * ld_src]);
+  __m512 j = _mm512_loadu_ps(&src[9 * ld_src]);
+  __m512 k = _mm512_loadu_ps(&src[10 * ld_src]);
+  __m512 l = _mm512_loadu_ps(&src[11 * ld_src]);
+  __m512 m = _mm512_loadu_ps(&src[12 * ld_src]);
+  __m512 n = _mm512_loadu_ps(&src[13 * ld_src]);
+  __m512 o = _mm512_loadu_ps(&src[14 * ld_src]);
+  __m512 p = _mm512_loadu_ps(&src[15 * ld_src]);
+
+  __m512 ta, tb, tc, td, te, tf, tg, th, ti, tj, tk, tl, tm, tn, to, tq;
+  // unpacking and interleaving 32-bit elements
+  // a0  b0  a1  b1  a4  b4  a5  b5  a8  b8  a9  b9  a12  b12 a13 b13
+  // a2  b2  a3  b3  a6  b6  a7  b7  a10 b10 a11 b11 a14  b14 a15 b15
+  // c0  d0  c1  d1 ...
+  // c2  d2  c3  d3 ...
+  // e0  f0  e1  f1 ...
+  // e2  f2  e3  f3 ...
+  // g0  h0  g1  h1 ...
+  // g2  h2  g3  h3 ...
+  // i0  ...
+  // i2  ...
+  // k0  ...
+  // k2  ...
+  // m0  ...
+  // m2  ...
+  // o0  ...
+  // o2  ...
+  ta = _mm512_unpacklo_ps(a, b);
+  tb = _mm512_unpackhi_ps(a, b);
+  tc = _mm512_unpacklo_ps(c, d);
+  td = _mm512_unpackhi_ps(c, d);
+  te = _mm512_unpacklo_ps(e, f);
+  tf = _mm512_unpackhi_ps(e, f);
+  tg = _mm512_unpacklo_ps(g, h);
+  th = _mm512_unpackhi_ps(g, h);
+  ti = _mm512_unpacklo_ps(i, j);
+  tj = _mm512_unpackhi_ps(i, j);
+  tk = _mm512_unpacklo_ps(k, l);
+  tl = _mm512_unpackhi_ps(k, l);
+  tm = _mm512_unpacklo_ps(m, n);
+  tn = _mm512_unpackhi_ps(m, n);
+  to = _mm512_unpacklo_ps(o, p);
+  tq = _mm512_unpackhi_ps(o, p);
+
+  // unpacking and interleaving 64-bit elements
+  //  a0  b0  c0  d0  a4  b4  c4  d4  a8  b8  c8  d8  a12 b12 c12 d12
+  //  a1  b1  c1  d1 ...
+  //  a2  b2  c2  d2 ...
+  //  a3  b3  c3  d3 ...
+  //  e0  f0  g0  h0  e4  f4  g4  h4  e8  f8  g8  h8  e12 f12 g12 h12
+  //  e1  f1  g1  h1 ...
+  //  e2  f2  g2  h2 ...
+  //  e3  f3  g3  h3 ...
+  //  i0  j0  k0  l0 ...
+  //  i1  j1  k1  l1 ...
+  //  i2  j2  k2  l2 ...
+  //  i3  j3  k3  l3 ...
+  //  m0  n0  o0  p0 ...
+  //  m1  n1  o1  p1 ...
+  //  m2  n2  o2  p2 ...
+  //  m3  n3  o3  p3 ...
+  a = _mm512_castpd_ps(
+      _mm512_unpacklo_pd(_mm512_castps_pd(ta), _mm512_castps_pd(tc)));
+  b = _mm512_castpd_ps(
+      _mm512_unpackhi_pd(_mm512_castps_pd(ta), _mm512_castps_pd(tc)));
+  c = _mm512_castpd_ps(
+      _mm512_unpacklo_pd(_mm512_castps_pd(tb), _mm512_castps_pd(td)));
+  d = _mm512_castpd_ps(
+      _mm512_unpackhi_pd(_mm512_castps_pd(tb), _mm512_castps_pd(td)));
+  e = _mm512_castpd_ps(
+      _mm512_unpacklo_pd(_mm512_castps_pd(te), _mm512_castps_pd(tg)));
+  f = _mm512_castpd_ps(
+      _mm512_unpackhi_pd(_mm512_castps_pd(te), _mm512_castps_pd(tg)));
+  g = _mm512_castpd_ps(
+      _mm512_unpacklo_pd(_mm512_castps_pd(tf), _mm512_castps_pd(th)));
+  h = _mm512_castpd_ps(
+      _mm512_unpackhi_pd(_mm512_castps_pd(tf), _mm512_castps_pd(th)));
+  i = _mm512_castpd_ps(
+      _mm512_unpacklo_pd(_mm512_castps_pd(ti), _mm512_castps_pd(tk)));
+  j = _mm512_castpd_ps(
+      _mm512_unpackhi_pd(_mm512_castps_pd(ti), _mm512_castps_pd(tk)));
+  k = _mm512_castpd_ps(
+      _mm512_unpacklo_pd(_mm512_castps_pd(tj), _mm512_castps_pd(tl)));
+  l = _mm512_castpd_ps(
+      _mm512_unpackhi_pd(_mm512_castps_pd(tj), _mm512_castps_pd(tl)));
+  m = _mm512_castpd_ps(
+      _mm512_unpacklo_pd(_mm512_castps_pd(tm), _mm512_castps_pd(to)));
+  n = _mm512_castpd_ps(
+      _mm512_unpackhi_pd(_mm512_castps_pd(tm), _mm512_castps_pd(to)));
+  o = _mm512_castpd_ps(
+      _mm512_unpacklo_pd(_mm512_castps_pd(tn), _mm512_castps_pd(tq)));
+  p = _mm512_castpd_ps(
+      _mm512_unpackhi_pd(_mm512_castps_pd(tn), _mm512_castps_pd(tq)));
+
+  //  shuffle 128-bits (composed of 4 32-bit elements)
+  //  a0  b0  c0  d0  a8  b8  c8  d8  e0  f0  g0  h0  e8  f8  g8  h8
+  //  a1  b1  c1  d1 ...
+  //  a2  b2  c2  d2 ...
+  //  a3  b3  c3  d3 ...
+  //  a4  b4  c4  d4 ...
+  //  a5  b5  c5  d5 ...
+  //  a6  b6  c6  d6 ...
+  //  a7  b7  c7  d7 ...
+  //  i0  j0  k0  l0  i8  j8  k8  l8  m0  n0  o0  p0  m8  n8  o8  p8
+  //  i1  j1  k1  l1 ...
+  //  i2  j2  k2  l2 ...
+  //  i3  j3  k3  l3 ...
+  //  i4  j4  k4  l4 ...
+  //  i5  j5  k5  l5 ...
+  //  i6  j6  k6  l6 ...
+  //  i7  j7  k7  l7 ...
+  ta = _mm512_shuffle_f32x4(a, e, 0x88);
+  tb = _mm512_shuffle_f32x4(b, f, 0x88);
+  tc = _mm512_shuffle_f32x4(c, g, 0x88);
+  td = _mm512_shuffle_f32x4(d, h, 0x88);
+  te = _mm512_shuffle_f32x4(a, e, 0xdd);
+  tf = _mm512_shuffle_f32x4(b, f, 0xdd);
+  tg = _mm512_shuffle_f32x4(c, g, 0xdd);
+  th = _mm512_shuffle_f32x4(d, h, 0xdd);
+  ti = _mm512_shuffle_f32x4(i, m, 0x88);
+  tj = _mm512_shuffle_f32x4(j, n, 0x88);
+  tk = _mm512_shuffle_f32x4(k, o, 0x88);
+  tl = _mm512_shuffle_f32x4(l, p, 0x88);
+  tm = _mm512_shuffle_f32x4(i, m, 0xdd);
+  tn = _mm512_shuffle_f32x4(j, n, 0xdd);
+  to = _mm512_shuffle_f32x4(k, o, 0xdd);
+  tq = _mm512_shuffle_f32x4(l, p, 0xdd);
+
+  //  shuffle 128-bits (composed of 4 32-bit elements)
+  //  a0  b0  c0  d0  ...  p0
+  //  a1  b1  c1  d1  ...  p1
+  //  a2  b2  c2  d2  ...  p2
+  //  a3  b3  c3  d3  ...  p3
+  //  a4  ...
+  //  a5  ...
+  //  a6  ...
+  //  a7  ...
+  //  a8  ...
+  //  a9  ...
+  //  a10 ...
+  //  a11 ...
+  //  a12 ...
+  //  a13 ...
+  //  a14 ...
+  //  a15 b15 c15 d15 ...  p15
+  a = _mm512_shuffle_f32x4(ta, ti, 0x88);
+  b = _mm512_shuffle_f32x4(tb, tj, 0x88);
+  c = _mm512_shuffle_f32x4(tc, tk, 0x88);
+  d = _mm512_shuffle_f32x4(td, tl, 0x88);
+  e = _mm512_shuffle_f32x4(te, tm, 0x88);
+  f = _mm512_shuffle_f32x4(tf, tn, 0x88);
+  g = _mm512_shuffle_f32x4(tg, to, 0x88);
+  h = _mm512_shuffle_f32x4(th, tq, 0x88);
+  i = _mm512_shuffle_f32x4(ta, ti, 0xdd);
+  j = _mm512_shuffle_f32x4(tb, tj, 0xdd);
+  k = _mm512_shuffle_f32x4(tc, tk, 0xdd);
+  l = _mm512_shuffle_f32x4(td, tl, 0xdd);
+  m = _mm512_shuffle_f32x4(te, tm, 0xdd);
+  n = _mm512_shuffle_f32x4(tf, tn, 0xdd);
+  o = _mm512_shuffle_f32x4(tg, to, 0xdd);
+  p = _mm512_shuffle_f32x4(th, tq, 0xdd);
+
+  // store from registers to dst
+  _mm512_storeu_ps(&dst[0 * ld_dst], a);
+  _mm512_storeu_ps(&dst[1 * ld_dst], b);
+  _mm512_storeu_ps(&dst[2 * ld_dst], c);
+  _mm512_storeu_ps(&dst[3 * ld_dst], d);
+  _mm512_storeu_ps(&dst[4 * ld_dst], e);
+  _mm512_storeu_ps(&dst[5 * ld_dst], f);
+  _mm512_storeu_ps(&dst[6 * ld_dst], g);
+  _mm512_storeu_ps(&dst[7 * ld_dst], h);
+  _mm512_storeu_ps(&dst[8 * ld_dst], i);
+  _mm512_storeu_ps(&dst[9 * ld_dst], j);
+  _mm512_storeu_ps(&dst[10 * ld_dst], k);
+  _mm512_storeu_ps(&dst[11 * ld_dst], l);
+  _mm512_storeu_ps(&dst[12 * ld_dst], m);
+  _mm512_storeu_ps(&dst[13 * ld_dst], n);
+  _mm512_storeu_ps(&dst[14 * ld_dst], o);
+  _mm512_storeu_ps(&dst[15 * ld_dst], p);
+}
+
 #endif
 
 }}}
diff --git a/aten/src/ATen/cpu/vec/vec_base.h b/aten/src/ATen/cpu/vec/vec_base.h
index abf106e..0a9d303 100644
--- a/aten/src/ATen/cpu/vec/vec_base.h
+++ b/aten/src/ATen/cpu/vec/vec_base.h
@@ -1027,4 +1027,15 @@
   return Vectorized<T>::loadu(static_cast<void*>(output));
 }
 
+// Transpose the `src` buffer of type `T` and size (M,N) into the `dst` buffer. `ld_src` is the leading
+// dimension of `src` and `ld_dst` is the leading dimension of `dst`.
+template <typename T, int M, int N>
+inline void transpose_mxn(const T* src, int64_t ld_src, T* dst, int64_t ld_dst) {
+  for (int i = 0; i < M; i++) {
+    for (int j = 0; j < N; j++) {
+      dst[j*ld_dst + i] = src[i*ld_src + j];
+    }
+  }
+}
+
 }}}
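The generic fallback added above also covers non-square shapes and buffers whose leading dimensions differ from the block being transposed. A small hypothetical example (not part of this patch):

```cpp
// Hypothetical usage of the generic scalar transpose_mxn: transpose a 2x3
// block that lives inside larger row-major buffers, so ld_src and ld_dst
// differ from the block shape.
#include <ATen/cpu/vec/vec.h>

int main() {
  float src[2 * 5] = {  // 2 rows with leading dimension 5; only 3 columns are used
      0, 1, 2, -1, -1,
      3, 4, 5, -1, -1};
  float dst[3 * 4] = {};  // 3 rows with leading dimension 4

  // Writes the 3x2 transpose of the block: dst[j*4 + i] == src[i*5 + j].
  at::vec::transpose_mxn<float, 2, 3>(src, /*ld_src=*/5, dst, /*ld_dst=*/4);
  return 0;
}
```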
diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py
index 88bea7c..380f6fa 100644
--- a/test/inductor/test_torchinductor.py
+++ b/test/inductor/test_torchinductor.py
@@ -5535,7 +5535,7 @@
                 traced = make_fx(fn)(x1, x2)
                 compiled = compile_fx_inner(traced, [x1, x2])
                 assert same(fn(x1, x2)[0], compiled([x1, x2])[0], equal_nan=True)
-                assert metrics.generated_cpp_vec_kernel_count == 0
+                assert metrics.generated_cpp_vec_kernel_count == 1
 
                 torch._dynamo.reset()
                 metrics.reset()
@@ -5569,6 +5569,96 @@
                     kernel_profile_events.append(e.name)
             assert len(kernel_profile_events) > 0
 
+        @unittest.skipIf(
+            not codecache.valid_vec_isa_list(), "Does not support vectorization"
+        )
+        def test_channel_shuffle_cl_output(self):
+            """code and shape extracted from shufflenet_v2_x1_0"""
+
+            def channel_shuffle(x, groups):
+                batchsize, num_channels, height, width = x.size()
+                channels_per_group = num_channels // groups
+                x = x.view(batchsize, groups, channels_per_group, height, width)
+                x = torch.transpose(x, 1, 2).contiguous()
+                x = x.view(batchsize, -1, height, width)
+                return x.contiguous(memory_format=torch.channels_last)
+
+            for simdlen in (None, 256, 1):
+                with patch.object(config.cpp, "simdlen", simdlen):
+                    torch._dynamo.reset()
+                    metrics.reset()
+                    x = torch.randn(64, 58, 28, 28)
+                    opt_fn = torch._dynamo.optimize("inductor")(channel_shuffle)
+                    same(channel_shuffle(x, 2), opt_fn(x, 2))
+                    if simdlen != 1:
+                        assert metrics.generated_cpp_vec_kernel_count == 1
+
+        @unittest.skipIf(
+            not codecache.valid_vec_isa_list(), "Does not support vectorization"
+        )
+        def test_transpose_with_norm(self):
+            """a sub-module from TIMM gmlp_s16_224"""
+
+            class Model(torch.nn.Module):
+                def __init__(self):
+                    super(Model, self).__init__()
+                    self.linear = torch.nn.Linear(
+                        in_features=256, out_features=1536, bias=True
+                    )
+                    self.act = torch.nn.GELU()
+                    self.norm = torch.nn.LayerNorm(768)
+                    self.proj = torch.nn.Linear(196, 196)
+                    self.fc = torch.nn.Linear(
+                        in_features=768, out_features=256, bias=True
+                    )
+
+                def forward(self, x):
+                    x = self.linear(x)
+                    x = self.act(x)
+                    u, v = x.chunk(2, dim=-1)
+                    v = self.norm(v)
+                    v = self.proj(v.transpose(-1, -2))
+                    y = u * v.transpose(-1, -2)
+                    return self.fc(y)
+
+            x = torch.randn(128, 196, 256)
+            for simdlen in (None, 256, 1):
+                with patch.object(config.cpp, "simdlen", simdlen):
+                    for eval_mode in [True, False]:
+                        torch._dynamo.reset()
+                        metrics.reset()
+                        m = Model().eval() if eval_mode else Model()
+                        opt_fn = torch._dynamo.optimize("inductor")(m)
+                        same(m(x), opt_fn(x))
+                        if simdlen != 1:
+                            assert metrics.generated_cpp_vec_kernel_count == 7
+
+        @unittest.skipIf(
+            not codecache.valid_vec_isa_list(), "Does not support vectorization"
+        )
+        def test_transpose_copy(self):
+            def fn(a):
+                return a.t().contiguous()
+
+            for simdlen in (None, 256, 1):
+                with patch.object(config.cpp, "simdlen", simdlen):
+                    for shape in (
+                        (7, 7),
+                        (8, 8),
+                        (9, 9),
+                        (16, 16),
+                        (17, 17),
+                        (32, 32),
+                        (33, 33),
+                    ):
+                        torch._dynamo.reset()
+                        metrics.reset()
+                        x = torch.randn(shape)
+                        opt_fn = torch._dynamo.optimize("inductor")(fn)
+                        same(fn(x), opt_fn(x))
+                        if simdlen != 1:
+                            assert metrics.generated_cpp_vec_kernel_count == 1
+
 
 if HAS_CUDA:
     import triton
diff --git a/torch/_inductor/codegen/cpp.py b/torch/_inductor/codegen/cpp.py
index 96b5d23..ed98785 100644
--- a/torch/_inductor/codegen/cpp.py
+++ b/torch/_inductor/codegen/cpp.py
@@ -6,6 +6,7 @@
 from copy import copy, deepcopy
 from pathlib import Path
 from typing import Dict, List
+from unittest.mock import patch
 
 import sympy
 
@@ -613,8 +614,19 @@
         self.reduction_prefix = IndentedBuffer()
         self.reduction_suffix = DeferredIndentedBuffer()
         self.reduction_var_map = {}
+        self.preloads = IndentedBuffer()
+        self.poststores = DeferredIndentedBuffer()
         self.num_threads = num_threads  # num_threads the kernel specialized for
 
+    def scale_index_with_offset(
+        self, index: sympy.Expr, scale, itervar_idx=-1, offset=0
+    ):
+        expanded_index = sympy.expand(index)
+        var = self.itervars[itervar_idx]
+        replacement = {var: var * scale + offset}
+        new_index = sympy_subs(expanded_index, replacement)
+        return new_index
+
     def load(self, name: str, index: sympy.Expr):
         var = self.args.input(name)
         index = self.rename_indexing(index)
@@ -715,10 +727,17 @@
                     stack.enter_context(code.indent())
 
             def gen_kernel(kernel):
-                assert kernel
-                code.splice(kernel.loads)
-                code.splice(kernel.compute)
-                code.splice(kernel.stores)
+                with contextlib.ExitStack() as stack:
+                    assert kernel
+                    if hasattr(kernel, "codegen_inner_loops"):
+                        code.splice(kernel.preloads)
+                        kernel.codegen_inner_loops(code)
+                        stack.enter_context(code.indent())
+                    code.splice(kernel.loads)
+                    code.splice(kernel.compute)
+                    code.splice(kernel.stores)
+                if hasattr(kernel, "codegen_inner_loops"):
+                    code.splice(kernel.poststores)
 
             def gen_loops(loops: List[LoopLevel], in_reduction=False):
                 with contextlib.ExitStack() as stack_outer:
@@ -808,39 +827,34 @@
 class CppVecKernel(CppKernel):
     overrides = CppVecOverrides
 
-    def __init__(self, args, num_threads):
+    def __init__(self, args, num_threads, tiling_factor=0):
         super(CppVecKernel, self).__init__(args, num_threads)
         assert codecache.pick_vec_isa()
-        self.simd_nelements = codecache.pick_vec_isa().nelements()
+        if tiling_factor == 0:
+            tiling_factor = codecache.pick_vec_isa().nelements()
+        self.tiling_factor = tiling_factor
         self.reduction_omp_dec: Dict[str, str] = {}
         self.var_vec_buf_map: Dict[str, str] = {}
         metrics.generated_cpp_vec_kernel_count += 1
 
-    def is_single_step_var(self, var: sympy.Symbol, index: sympy.Expr):
+    def stride_at(self, var: sympy.Symbol, index: sympy.Expr):
         replacement = {var: var + 1}
         new_index = sympy_subs(index, replacement)
-        delta = sympy.simplify(new_index - index)
-        return delta == 1
+        return sympy.simplify(new_index - index)
 
-    def is_var_irrevelant(self, var: sympy.Symbol, index: sympy.Expr):
+    def is_stride1_at(self, var: sympy.Symbol, index: sympy.Expr):
+        return self.stride_at(var, index) == 1
+
+    def is_invariant_under(self, var: sympy.Symbol, index: sympy.Expr):
         expanded_index = sympy.expand(index)
         return not expanded_index.has(var)
 
-    def transform_index(self, index: sympy.Expr):
-        expanded_index = sympy.expand(index)
-        assert self.simd_nelements
-        assert self.simd_nelements >= 1
-        most_inner_var = self.itervars[-1]
-        replacement = {most_inner_var: most_inner_var * self.simd_nelements}
-        new_index = sympy_subs(expanded_index, replacement)
-        return new_index
-
     def load(self, name: str, index: sympy.Expr):
         var = self.args.input(name)
         index = self.rename_indexing(index)
 
         expanded_index = sympy.expand(index)
-        new_index = self.transform_index(index)
+        new_index = self.scale_index_with_offset(index, self.tiling_factor)
 
         if expanded_index == new_index:
             line = f"at::vec::Vectorized<float>({var}[{cexpr(index)}])"
@@ -868,7 +882,7 @@
         assert mode is None
 
         expanded_index = sympy.expand(index)
-        new_index = self.transform_index(index)
+        new_index = self.scale_index_with_offset(index, self.tiling_factor)
         assert new_index != expanded_index
         line = f"{value}.store({var} + {cexpr(new_index)});"
         self.stores.writeline(name, line)
@@ -934,9 +948,187 @@
         self.cse.store_cache[name] = tmpvar
 
 
+class CppTile2DKernel(CppVecKernel):
+    """
+    A vector kernel that handles the 2d tiles with the tile size defined in `tiling_factor` on
+    the inner-most loop level and one of the outer loop levels (`outer_tiling_idx`). When the data
+    tile is accessed in a contiguous way from the outer loop axis, a transposition is applied on the
+    tile to make the access contiguous from the inner-most loop axis. Then, the same vectorization
+    logic from its parent `CppVecKernel` is leveraged for load/store/compute. The transposed tile load
+    and store are generated into kernel.preloads and kernel.poststores buffers.
+
+    The loop structure looks like below:
+    for ...
+      for i_outer ...
+        for ...
+          for inner_most ...
+            // generated by CppTile2DKernel
+            float tmp0[16*16]; at::vec::transpose_mxn<...>(tmp0, in_ptr0 + ..., ...); // into kernel.preloads
+            float tmp1[16*16]; // into kernel.preloads
+            for i_inner ... { // the kernel inner loop
+              vectorized loads/compute/stores (e.g., load tmp0, store tmp1) // into kernel.loads/compute/stores
+            }
+            at::vec::transpose_mxn(out_ptr0 + ..., tmp1, ...) // into kernel.poststores
+          for inner_most ... (tail)
+            // generated by CppTile2DTailKernel
+            ...
+      for i_outer ... (tail)
+        for ...
+          for ...
+            // generated by CppKernel
+            ...
+    """
+
+    def __init__(self, args, num_threads, tiling_factor, outer_tiling_idx):
+        super().__init__(args, num_threads, tiling_factor)
+        self.outer_tiling_idx = outer_tiling_idx
+
+    def inner_itervar(self):
+        return sympy.symbols(f"{self.itervars[self.outer_tiling_idx]}_inner")
+
+    def need_vec_transpose(self, index):
+        return self.is_stride1_at(
+            self.itervars[self.outer_tiling_idx], index
+        ) and not self.is_invariant_under(self.itervars[-1], index)
+
+    def gen_transposed_tile_load_store(self, name, var, index, is_store):
+        # transposed tile load/store outside the kernel inner loop
+        factor = self.tiling_factor
+        new_index = self.scale_index_with_offset(index, factor, itervar_idx=-1)
+        new_index = self.scale_index_with_offset(
+            new_index, factor, itervar_idx=self.outer_tiling_idx
+        )
+
+        src = f"{var} + {cexpr(new_index)}"
+        dst = "__place_holder__"
+        ld_src = f"{cexpr(self.stride_at(self.itervars[-1], index))}"
+        ld_dst = f"{factor}"
+        if is_store:
+            src, dst = dst, src
+            ld_src, ld_dst = ld_dst, ld_src
+
+        need_define = True
+        load_or_store = f"at::vec::transpose_mxn<float,{factor},{factor}>({src}, {ld_src}, {dst}, {ld_dst});"
+        if is_store:
+            tile_var = self.cse.newvar()
+        elif load_or_store not in self.cse.cache:
+            tile_var = self.cse.generate(self.preloads, load_or_store, write=False)
+        else:
+            need_define = False
+            tile_var = self.cse.cache[load_or_store]
+
+        if need_define:
+            define_line = f"float {tile_var}[{factor}*{factor}] __attribute__ ((aligned ({factor})));"
+            self.preloads.writeline(define_line)
+
+        load_or_store = load_or_store.replace("__place_holder__", str(tile_var))
+        if is_store:
+            self.poststores.writeline(name, load_or_store)
+        else:
+            self.preloads.writeline(load_or_store)
+
+        return tile_var
+
+    def load(self, name: str, index: sympy.Expr):
+        var = self.args.input(name)
+        index = self.rename_indexing(index)
+
+        inner = self.inner_itervar()
+        expanded_index = sympy.expand(index)
+        if self.need_vec_transpose(expanded_index):
+            tile_var = self.gen_transposed_tile_load_store(
+                name, var, expanded_index, is_store=False
+            )
+            # vector load inside the kernel inner loop
+            line = f"at::vec::Vectorized<float>::loadu({tile_var} + {cexpr(inner * self.tiling_factor)})"
+            return self.cse.generate(self.loads, line)
+        else:
+            new_index = self.scale_index_with_offset(
+                expanded_index,
+                self.tiling_factor,
+                itervar_idx=self.outer_tiling_idx,
+                offset=inner,
+            )
+            return super().load(name, new_index)
+
+    def store(self, name, index, value, mode=None):
+        assert "buf" in name
+        var = self.args.output(name)
+
+        inner = self.inner_itervar()
+        index = self.rename_indexing(index)
+        assert mode is None
+        # TODO(jgong5): assert the index is an affine expression on the itervars in concern
+        expanded_index = sympy.expand(index)
+        if self.need_vec_transpose(expanded_index):
+            tile_var = self.gen_transposed_tile_load_store(
+                name, var, expanded_index, is_store=True
+            )
+            # vector store inside the kernel inner loop
+            line = f"{value}.store({tile_var} + {cexpr(inner * self.tiling_factor)});"
+            self.stores.writeline(name, line)
+        else:
+            new_index = self.scale_index_with_offset(
+                expanded_index,
+                self.tiling_factor,
+                itervar_idx=self.outer_tiling_idx,
+                offset=inner,
+            )
+            super().store(name, new_index, value, mode)
+
+    def codegen_inner_loops(self, code):
+        inner = self.inner_itervar()
+        code.writeline(
+            f"for (long {inner} = 0; {inner} < {self.tiling_factor}; {inner}++)"
+        )
+
+
+class CppTile2DTailKernel(CppKernel):
+    """
+    A scalar kernel that handles the tail of the inner-most loop split from a 2d tiling. The tile of the outer
+    loop axis is handled with a kernel inner loop (see method `codegen_inner_loops`).
+    """
+
+    def __init__(self, args, num_threads, tiling_factor, outer_tiling_idx):
+        super().__init__(args, num_threads)
+        self.outer_tiling_idx = outer_tiling_idx
+        self.tiling_factor = tiling_factor
+
+    def inner_itervar(self):
+        return sympy.symbols(f"{self.itervars[self.outer_tiling_idx]}_inner")
+
+    def transform_index(self, index):
+        index = self.rename_indexing(index)
+        expanded_index = sympy.expand(index)
+        new_index = self.scale_index_with_offset(
+            expanded_index,
+            self.tiling_factor,
+            itervar_idx=self.outer_tiling_idx,
+            offset=self.inner_itervar(),
+        )
+        return new_index
+
+    def load(self, name: str, index: sympy.Expr):
+        new_index = self.transform_index(index)
+        return super().load(name, new_index)
+
+    def store(self, name, index, value, mode=None):
+        assert "buf" in name
+        var = self.args.output(name)
+        assert mode is None
+        new_index = self.transform_index(index)
+        super().store(name, new_index, value, mode)
+
+    def codegen_inner_loops(self, code):
+        inner = self.inner_itervar()
+        code.writeline(
+            f"for (long {inner} = 0; {inner} < {self.tiling_factor}; {inner}++)"
+        )
+
+
 class CppVecKernelChecker(CppVecKernel):
-    def __init__(self, args, num_threads):
-        super(CppVecKernelChecker, self).__init__(args, num_threads)
+    def __init__(self, args, num_threads, tiling_factor):
+        super(CppVecKernelChecker, self).__init__(args, num_threads, tiling_factor)
 
         # Since this kernel is only a checker and does not generate any
         # code, we need to decrease the kernel count.
@@ -954,9 +1146,6 @@
                 self.fast_vec_list.append(k)
         self.exit_stack = contextlib.ExitStack()
 
-    def is_legal_data_access(self, var: sympy.Symbol, index: sympy.Expr):
-        return self.is_var_irrevelant(var, index) or self.is_single_step_var(var, index)
-
     def could_vec(self, name: str, index: sympy.Expr):
         assert self.itervars is not None
         # Not a loop
@@ -964,7 +1153,9 @@
             return False
 
         most_inner_var = self.itervars[-1]
-        return self.is_legal_data_access(most_inner_var, index)
+        return self.is_invariant_under(most_inner_var, index) or self.is_stride1_at(
+            most_inner_var, index
+        )
 
     def load(self, name: str, index: sympy.Expr):
         if not V.graph.get_dtype(name) in [
@@ -1081,45 +1272,178 @@
         return self
 
 
+class CppTile2DKernelChecker(CppVecKernelChecker):
+    """
+    Currently, we only address situations with the following constraints:
+    1. There exists one and only one fp32 load/store with outer loop var having contiguous buffer accesses.
+    2. When a load/store doesn't have contiguous access in an outer loop var, the access should be
+       vectorizable from the inner-most dim.
+    3. No reduction.
+    """
+
+    def __init__(self, args, num_threads, tiling_factor):
+        super().__init__(args, num_threads, tiling_factor)
+        self.can_tile2d = True
+        self.outer_tiling_idx = -1
+
+    def check_can_tile2d(self, name: str, index: sympy.Expr):
+        if not self.can_tile2d:
+            return
+        # check contiguity from any of the outer loops
+        has_stride1 = False
+        for loop_idx, itervar in enumerate(self.itervars[:-1]):
+            if self.is_stride1_at(itervar, index):
+                # only support 2d tile now
+                if V.graph.get_dtype(name) not in [torch.float, torch.float32] or (
+                    self.outer_tiling_idx >= 0 and self.outer_tiling_idx != loop_idx
+                ):
+                    self.can_tile2d = False
+                    return
+                else:
+                    self.outer_tiling_idx = loop_idx
+                has_stride1 = True
+        if not has_stride1 and not self.could_vec(name, index):
+            self.can_tile2d = False
+        return self.can_tile2d
+
+    def load(self, name: str, index: sympy.Expr):
+        if not V.graph.get_dtype(name) in [
+            torch.float,
+            torch.float32,
+            torch.bool,
+            torch.uint8,
+        ]:
+            self.can_tile2d = False
+            return self.can_tile2d
+        index = self.rename_indexing(index)
+        return self.check_can_tile2d(name, index)
+
+    def store(self, name, index, value, mode=None):
+        if not V.graph.get_dtype(name) in [
+            torch.float,
+            torch.float32,
+        ]:
+            self.can_tile2d = False
+            return self.can_tile2d
+        index = self.rename_indexing(index)
+        return self.check_can_tile2d(name, index)
+
+    def reduction(self, name, dtype, src_dtype, reduction_type, index, value):
+        self.can_tile2d = False
+        return self.can_tile2d
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        super().__exit__(exc_type, exc_val, exc_tb)
+        if not self.simd_vec or self.outer_tiling_idx < 0:
+            self.can_tile2d = False
+
+
 class CppKernelProxy(CppKernel):
-    def __init__(
-        self,
-        args,
-        num_threads,
-        simd_vec_kernel: CppVecKernel,
-        simd_omp_kernel: CppKernel,
-    ):
-        super(CppKernelProxy, self).__init__(args, num_threads)
-        self.simd_vec_kernel = simd_vec_kernel
-        self.simd_omp_kernel = simd_omp_kernel
-        assert simd_omp_kernel, "Expect cpp scalar kernel always exists"
-        self.call_ranges = simd_omp_kernel.call_ranges
-        self.ranges = simd_omp_kernel.ranges
-        self.itervars = simd_omp_kernel.itervars
-        self.reduction_depth = simd_omp_kernel.reduction_depth
+    def __init__(self, kernel_group):
+        super(CppKernelProxy, self).__init__(
+            kernel_group.args, kernel_group.ws.num_threads
+        )
+        self.kernel_group = kernel_group
+        self.loop_nest = None
+        self.call_ranges = None
         self.picked_vec_isa: codecache.VecISA = codecache.pick_vec_isa()
 
-    def codegen_loops(self, code, worksharing):
-        if self.simd_vec_kernel is None or not self.picked_vec_isa:
-            assert self.simd_omp_kernel
-            return self.simd_omp_kernel.codegen_loops(code, worksharing)
+    def codegen_nodes(self, nodes):
+        kernel_group = self.kernel_group
+        _, (group, reduction_group) = max(
+            nodes, key=lambda x: int(x.is_reduction())
+        ).group
 
-        assert self.picked_vec_isa
-        loop_nest = LoopNestWithSplit.build(self.simd_omp_kernel)
-        main_loop, tail_loop = loop_nest.split_with_tiling(
-            len(self.simd_vec_kernel.itervars) - 1, self.simd_vec_kernel.simd_nelements
-        )
-        main_loop.set_kernel(self.simd_vec_kernel)
-        tail_loop.set_kernel(self.simd_omp_kernel)
-        main_loop.simd_vec = True
-        tail_loop.simd_omp = True
-        # We chope the loop into two cubes by the nelements - main loop and tail loop.
-        # Regarding the main loop, it is straightforward that it could be vectorized with
-        # nelements. But for the tail loop, it still could be vectorized. For example,
-        # if the nelements is 8(256bits), then the tail loop still could be vectorized
-        # as 4(128bits).
-        tail_loop.simd_nelements = self.simd_vec_kernel.simd_nelements // 2
-        self.codegen_loops_impl(loop_nest, code, worksharing)
+        def codegen_kernel(cls, *args):
+            with kernel_group.new_kernel(cls, *args) as kernel:
+                run(kernel)
+
+                # Ugly hack to maintain the metrics kernel count since
+                # we only count in CppKernelProxy, not those contained in it
+                metrics.generated_kernel_count -= 1
+
+                return kernel
+
+        def run(kernel):
+            vars, reduction_vars = kernel.set_ranges(group, reduction_group)
+            in_suffix = False
+            for node in nodes:
+                if node.group[1] in [
+                    (group, reduction_group),
+                    (group + reduction_group, ()),
+                ]:
+                    assert not in_suffix
+                    node.run(vars, reduction_vars)
+                else:
+                    in_suffix = True
+                    assert node.group[1] == (
+                        group,
+                        (),
+                    ), f"unexpected group: {node.group[1]} != {group}, {reduction_group}"
+                    # we can fuse in some extra pointwise into the suffix
+                    with kernel.write_to_suffix():
+                        node.run(vars, ())
+
+        scalar_kernel = codegen_kernel(CppKernel)
+        inner_most_idx = len(scalar_kernel.itervars) - 1
+        self.call_ranges = scalar_kernel.call_ranges
+        self.loop_nest = LoopNestWithSplit.build(scalar_kernel)
+
+        if not self.picked_vec_isa:
+            return
+
+        # TODO(jgong5): support alternative tiling factors and data types
+        tiling_factor = self.picked_vec_isa.nelements(dtype=torch.float)
+
+        # Kernels share the same global contexts like V.graph.wrapper_code, V.kernel.args.
+        # But the generated scalar kernel has updated these global contexts. Hence, the other kernels
+        # should not do this again to avoid context conflict. For now, we only control the
+        # config.inplace_buffers. In the future, we could maintain more contexts.
+        with patch.object(torch._inductor.config, "inplace_buffers", False):
+
+            with CppVecKernelChecker(
+                deepcopy(self.kernel_group.args), parallel_num_threads(), tiling_factor
+            ) as vec_checker:
+                run(vec_checker)
+
+            with CppTile2DKernelChecker(
+                deepcopy(self.kernel_group.args), parallel_num_threads(), tiling_factor
+            ) as tile2d_checker:
+                run(tile2d_checker)
+
+            if vec_checker.simd_vec:
+                main_loop, tail_loop = self.loop_nest.split_with_tiling(
+                    inner_most_idx, factor=tiling_factor
+                )
+                main_loop.set_kernel(codegen_kernel(CppVecKernel, tiling_factor))
+                tail_loop.set_kernel(scalar_kernel)
+                main_loop.simd_vec = True
+                tail_loop.simd_omp = True
+                # We chop the loop into two parts by nelements: the main loop and the tail
+                # loop. The main loop is straightforwardly vectorized with nelements, while
+                # the tail loop can still be vectorized with a smaller width. For example,
+                # if nelements is 8 (256 bits), the tail loop can still be vectorized
+                # as 4 (128 bits).
+                tail_loop.simd_nelements = tiling_factor // 2
+            elif tile2d_checker.can_tile2d:
+                outer_tiling_idx = tile2d_checker.outer_tiling_idx
+                assert outer_tiling_idx < inner_most_idx
+                outer_main_loop, outer_tail_loop = self.loop_nest.split_with_tiling(
+                    outer_tiling_idx, factor=tiling_factor
+                )
+                outer_tail_loop.set_kernel(scalar_kernel)
+                inner_main_loop, inner_tail_loop = outer_main_loop.split_with_tiling(
+                    inner_most_idx - outer_tiling_idx, factor=tiling_factor
+                )
+                inner_main_loop.set_kernel(
+                    codegen_kernel(CppTile2DKernel, tiling_factor, outer_tiling_idx)
+                )
+                inner_tail_loop.set_kernel(
+                    codegen_kernel(CppTile2DTailKernel, tiling_factor, outer_tiling_idx)
+                )
+
+    def codegen_loops(self, code, worksharing):
+        self.codegen_loops_impl(self.loop_nest, code, worksharing)
 
 
 class CppScheduling:
@@ -1153,111 +1477,14 @@
     def can_fuse_vertical(cls, node1, node2):
         return cls.can_fuse_horizontal(node1, node2) and not node1.is_reduction()
 
-    def can_vec(self, nodes):
-        if not codecache.pick_vec_isa():
-            return False
-
-        _, (group, reduction_group) = max(
-            nodes, key=lambda x: int(x.is_reduction())
-        ).group
-
-        with CppVecKernelChecker(
-            deepcopy(self.kernel_group.args), parallel_num_threads()
-        ) as kernel_checker:
-            vars, reduction_vars = kernel_checker.set_ranges(group, reduction_group)
-            for node in nodes:
-                if node.group[1] in [
-                    (group, reduction_group),
-                    (group + reduction_group, ()),
-                ]:
-                    node.run(vars, reduction_vars)
-                else:
-                    assert node.group[1] == (
-                        group,
-                        (),
-                    ), f"unexpected group: {node.group[1]} != {group}, {reduction_group}"
-                    node.run(vars, ())
-
-            return kernel_checker.simd_vec
-
-    def _codegen_nodes_impl(self, nodes, is_simd_vec=False):
-        """
-        Turn an set of pre-fused nodes into a C++ kernel.
-        """
-        kernel_group = self.kernel_group
-        _, (group, reduction_group) = max(
-            nodes, key=lambda x: int(x.is_reduction())
-        ).group
-
-        def create_kernel(_is_simd_vec):
-            in_suffix = False
-
-            with kernel_group.new_kernel(_is_simd_vec) as kernel:
-                vars, reduction_vars = kernel.set_ranges(group, reduction_group)
-
-                for node in nodes:
-                    if node.group[1] in [
-                        (group, reduction_group),
-                        (group + reduction_group, ()),
-                    ]:
-                        assert not in_suffix
-                        node.run(vars, reduction_vars)
-                    else:
-                        in_suffix = True
-                        assert node.group[1] == (
-                            group,
-                            (),
-                        ), f"unexpected group: {node.group[1]} != {group}, {reduction_group}"
-                        # we can fuse in some extra pointwise into the suffix
-                        with kernel.write_to_suffix():
-                            node.run(vars, ())
-                return kernel
-
-        org_inplace_buffers_flag = config.inplace_buffers
-        if is_simd_vec:
-            # Create vectorization kernel
-            cpp_vec_kernel = create_kernel(True)
-
-            # Since a kernel is divided into two parts - vectorization and non-vectorization.
-            # And the two parts share the same global contexts like V.graph.wrapper_code,
-            # V.kernel.args. But the vectorization kernel generation has updated these global
-            # contexts. Hence, the non-vectorization kernel should not do this again to avoid
-            # conext conflict. By now, we only control the config.inplace_buffers. In the future,
-            # we could maintain more contexts.
-            config.inplace_buffers = False
-
-            # Create non-vectorization kernel
-            cpp_kernel = create_kernel(False)
-
-            # Restore the inplace_buffers flag
-            config.inplace_buffers = org_inplace_buffers_flag
-            return (cpp_vec_kernel, cpp_kernel)
-        else:
-            return (None, create_kernel(False))
-
     def codegen_nodes(self, nodes):
         """
         Turn an set of pre-fused nodes into a C++ kernel.
         """
         kernel_group = self.kernel_group
 
-        can_be_simd_vec = self.can_vec(nodes)
-        simd_vec_kernel, simd_omp_kernel = self._codegen_nodes_impl(
-            nodes, can_be_simd_vec
-        )
-
-        assert simd_omp_kernel
-        metrics.generated_kernel_count -= 1
-        # Maitain the metrics kernel count
-        if simd_vec_kernel:
-            metrics.generated_kernel_count -= 1
-
-        cpp_kernel_proxy = CppKernelProxy(
-            kernel_group.args,
-            kernel_group.ws.num_threads,
-            simd_vec_kernel,
-            simd_omp_kernel,
-        )
+        cpp_kernel_proxy = CppKernelProxy(kernel_group)
+        cpp_kernel_proxy.codegen_nodes(nodes)
 
         kernel_group.finalize_kernel(cpp_kernel_proxy, None)
 
@@ -1279,11 +1506,8 @@
         self.stack.enter_context(self.ws)
         self.count = 0
 
-    def new_kernel(self, simd_vec=False):
-        if simd_vec:
-            return CppVecKernel(self.args, parallel_num_threads())
-        else:
-            return CppKernel(self.args, parallel_num_threads())
+    def new_kernel(self, cls, *args):
+        return cls(self.args, parallel_num_threads(), *args)
 
     def finalize_kernel(self, new_kernel, scheduler):
         self.count += 1