Revert "Quant: add weight int4pack mm kernel (#110914)"

This reverts commit 9980876cab9dcedce7d7dd1c8a2e168b548eaa36.

Reverted https://github.com/pytorch/pytorch/pull/110914 on behalf of https://github.com/jeanschmidt due to Breaking internal builds ([comment](https://github.com/pytorch/pytorch/pull/110914#issuecomment-1765302621))
diff --git a/aten/src/ATen/native/cuda/int4mm.cu b/aten/src/ATen/native/cuda/int4mm.cu
deleted file mode 100644
index 8f64e14..0000000
--- a/aten/src/ATen/native/cuda/int4mm.cu
+++ /dev/null
@@ -1,1020 +0,0 @@
-#if (defined(CUDA_VERSION) && CUDA_VERSION >= 12000) && (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 800))
-#include <cuda_bf16.h>
-#include <cuda_fp16.h>
-#include <cuda_runtime.h>
-#include <mma.h>
-#endif
-#include <ATen/ATen.h>
-#include <ATen/core/Tensor.h>
-#include <ATen/cuda/CUDAContext.h>
-#include <ATen/DeviceGuard.h>
-#include <c10/cuda/CUDAGuard.h>
-#include <torch/types.h>
-
-
-namespace at::native {
-
-template <typename U, typename V>
-constexpr __host__ __device__ auto divDown(U a, V b) -> decltype(a + b) {
-  static_assert(std::is_integral<U>::value && std::is_integral<V>::value, "");
-  return (a / b);
-}
-
-template <typename U, typename V>
-constexpr __host__ __device__ auto divUp(U a, V b) -> decltype(a + b) {
-  static_assert(std::is_integral<U>::value && std::is_integral<V>::value, "");
-  return (a + b - 1) / b;
-}
-
-template <typename U, typename V>
-constexpr __host__ __device__ auto roundDown(U a, V b) -> decltype(a + b) {
-  static_assert(std::is_integral<U>::value && std::is_integral<V>::value, "");
-  return divDown(a, b) * b;
-}
-
-template <typename U, typename V>
-constexpr __host__ __device__ auto roundUp(U a, V b) -> decltype(a + b) {
-  static_assert(std::is_integral<U>::value && std::is_integral<V>::value, "");
-  return divUp(a, b) * b;
-}
-
-template <typename U, typename V>
-constexpr __host__ __device__ bool isEvenDivisor(U a, V b) {
-  static_assert(std::is_integral<U>::value && std::is_integral<V>::value, "");
-  return (a % V(b) == 0) && ((a / V(b)) >= 1);
-}
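-
-// A few compile-time sanity checks for the rounding helpers above, in the
-// same static_assert style used below for pow2/log2 (illustrative only).
-static_assert(divDown(10, 4) == 2 && divUp(10, 4) == 3, "divDown/divUp");
-static_assert(roundDown(10, 4) == 8 && roundUp(10, 4) == 12, "roundDown/roundUp");
-static_assert(isEvenDivisor(12, 4) && !isEvenDivisor(12, 5), "isEvenDivisor");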
-
-template <class T>
-constexpr __host__ __device__ T pow(T n, int power) {
-  return (power > 0 ? n * pow(n, power - 1) : 1);
-}
-
-template <class T>
-constexpr __host__ __device__ T pow2(int power) {
-  return pow(2, power);
-}
-
-static_assert(pow2<int>(8) == 256, "pow2");
-
-template <typename T>
-constexpr __host__ __device__ int log2(T n, int p = 0) {
-  return (n <= 1) ? p : log2(n / 2, p + 1);
-}
-
-static_assert(log2(2) == 1, "log2");
-static_assert(log2(3) == 1, "log2");
-static_assert(log2(4) == 2, "log2");
-
-template <typename T>
-constexpr __host__ __device__ bool isPowerOf2(T v) {
-  static_assert(std::is_integral<T>::value, "");
-  return (v && !(v & (v - 1)));
-}
-
-static_assert(isPowerOf2(2048), "isPowerOf2");
-static_assert(!isPowerOf2(3333), "isPowerOf2");
-
-template <typename T>
-constexpr __host__ __device__ T nextHighestPowerOf2(T v) {
-  static_assert(std::is_integral<T>::value, "");
-  return (isPowerOf2(v) ? (T)2 * v : ((T)1 << (log2(v) + 1)));
-}
-
-static_assert(nextHighestPowerOf2(1) == 2, "nextHighestPowerOf2");
-static_assert(nextHighestPowerOf2(2) == 4, "nextHighestPowerOf2");
-static_assert(nextHighestPowerOf2(3) == 4, "nextHighestPowerOf2");
-static_assert(nextHighestPowerOf2(4) == 8, "nextHighestPowerOf2");
-
-static_assert(nextHighestPowerOf2(15) == 16, "nextHighestPowerOf2");
-static_assert(nextHighestPowerOf2(16) == 32, "nextHighestPowerOf2");
-static_assert(nextHighestPowerOf2(17) == 32, "nextHighestPowerOf2");
-
-static_assert(
-    nextHighestPowerOf2(1536000000u) == 2147483648u,
-    "nextHighestPowerOf2");
-static_assert(
-    nextHighestPowerOf2((size_t)2147483648ULL) == (size_t)4294967296ULL,
-    "nextHighestPowerOf2");
-
-template <typename T>
-constexpr __host__ __device__ T nextLowestPowerOf2(T v) {
-  static_assert(std::is_integral<T>::value, "");
-  return (isPowerOf2(v) ? v / (T)2 : ((T)1 << (log2(v))));
-}
-
-static_assert(nextLowestPowerOf2(1) == 0, "nextLowestPowerOf2");
-static_assert(nextLowestPowerOf2(2) == 1, "nextLowestPowerOf2");
-static_assert(nextLowestPowerOf2(3) == 2, "nextLowestPowerOf2");
-static_assert(nextLowestPowerOf2(4) == 2, "nextLowestPowerOf2");
-
-static_assert(nextLowestPowerOf2(15) == 8, "nextLowestPowerOf2");
-static_assert(nextLowestPowerOf2(16) == 8, "nextLowestPowerOf2");
-static_assert(nextLowestPowerOf2(17) == 16, "nextLowestPowerOf2");
-
-inline __host__ __device__ bool isPointerAligned(const void* p, int align) {
-  return reinterpret_cast<uintptr_t>(p) % align == 0;
-}
-
-// Returns the increment needed to align the pointer to the next highest
-// aligned address
-template <int Align>
-inline __host__ __device__ uint32_t getAlignmentRoundUp(const void* p) {
-  static_assert(isPowerOf2(Align), "");
-  uint32_t diff = uint32_t(uintptr_t(p) & uintptr_t(Align - 1));
-  return diff == 0 ? 0 : uint32_t(Align) - diff;
-}
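-
-// Illustrative example: with Align = 16 and a pointer whose address ends in
-// 0x...4, diff = 4, so getAlignmentRoundUp<16> returns 12; advancing the
-// pointer by 12 bytes reaches the next 16-byte boundary.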
-
-constexpr int32_t kWarpSize = 32;
-
-#if (defined(CUDA_VERSION) && CUDA_VERSION >= 12000) && (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 800))
-// f16 vector types
-struct __align__(2) f16x1 {
-  __half vals[1];
-};
-
-struct __align__(4) f16x2 {
-  __half vals[2];
-};
-
-struct __align__(8) f16x4 {
-  __half vals[4];
-};
-
-struct __align__(16) f16x8 {
-  __half vals[8];
-};
-
-// bf16 vector types
-struct __align__(2) bf16x1 {
-  __nv_bfloat16 vals[1];
-};
-
-struct __align__(4) bf16x2 {
-  __nv_bfloat16 vals[2];
-};
-
-struct __align__(8) bf16x4 {
-  __nv_bfloat16 vals[4];
-};
-
-struct __align__(16) bf16x8 {
-  __nv_bfloat16 vals[8];
-};
-
-// bf162 vector types
-struct __align__(4) bf16x2x1 {
-  __nv_bfloat162 vals[1];
-};
-
-struct __align__(8) bf16x2x2 {
-  __nv_bfloat162 vals[2];
-};
-
-struct __align__(16) bf16x2x4 {
-  __nv_bfloat162 vals[4];
-};
-
-struct __align__(16) bf16x2x4_u32 {
-  uint32_t vals[4];
-};
-
-struct __align__(8) bf16x2x2_u32 {
-  uint32_t vals[2];
-};
-
-struct __align__(4) bf16x2x1_u32 {
-  uint32_t vals[1];
-};
-
-template <typename T, int N>
-struct __align__(sizeof(T) * N) VectorType {
-  T vals[N];
-};
-
-// from
-// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
-inline __device__ bf16x2x4 convert_i4x8_to_bf16x2x4(uint32_t source) {
-  bf16x2x4 result;
-  constexpr int kElements = 8;
-
-  uint32_t* h = reinterpret_cast<uint32_t*>(&result);
-  uint32_t const source_i4s = source;
-
-  // First, we extract the i4s and construct an intermediate fp16 number.
-  static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa;
-  static constexpr uint32_t MASK = 0x000f000f;
-  static constexpr uint32_t I4s_TO_BF16s_MAGIC_NUM = 0x43004300;
-
-  // bf16 has fewer mantissa bits than fp16, so we cannot fold away as much of
-  // the shift overhead and must loop. No shift is needed for the first item.
-  uint32_t i4s = source_i4s;
-  asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
-               : "=r"(h[0])
-               : "r"(i4s), "n"(MASK), "n"(I4s_TO_BF16s_MAGIC_NUM), "n"(immLut));
-#pragma unroll
-  for (int ii = 1; ii < kElements / 2; ++ii) {
-    // Shift by 4 to expose the next nibble in each 16-bit half: the packing
-    // (see matrix_to_m16n8k16_Bint4_layout) interleaves values across the
-    // halves, and MASK extracts one nibble from each half per step.
-    i4s >>= 4;
-    // (i4s & 0x000f000f) | 0x43004300
-    asm volatile(
-        "lop3.b32 %0, %1, %2, %3, %4;\n"
-        : "=r"(h[ii])
-        : "r"(i4s), "n"(MASK), "n"(I4s_TO_BF16s_MAGIC_NUM), "n"(immLut));
-  }
-
-  // This is the BF16 {-136, -136} represented as an integer.
-  static constexpr uint32_t BF16_BIAS = 0xC308C308;
-  static constexpr uint32_t BF16_ONE = 0x3F803F80;
-
-// Finally, we construct the output numbers.
-#pragma unroll
-  for (int ii = 0; ii < kElements / 2; ++ii) {
-    // Since this section is for Ampere+, we use bf16 fma to do the bias
-    // subtraction
-    asm("fma.rn.bf16x2 %0, %1, %2, %3;\n"
-        : "=r"(h[ii])
-        : "r"(h[ii]), "r"(BF16_ONE), "r"(BF16_BIAS));
-  }
-
-  return result;
-}
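-
-// Worked example of the conversion above (for reference): 0x4300 is bf16
-// 128.0, and OR-ing a nibble x in [0, 15] into the low mantissa bits yields
-// exactly 128 + x (the mantissa ulp at exponent 2^7 is 1.0). The fma with
-// BF16_ONE and BF16_BIAS (-136) then produces (128 + x) - 136 = x - 8, so
-// nibble 0 maps to -8 and nibble 15 maps to 7, i.e. the signed int4 range.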
-
-
-
-enum class KReductionType {
-  // No k-reduction is needed between blocks, as the number of k-tiles
-  // processed per block is exact and we can directly write the output
-  None,
-};
-
-// Loads the A matrix in 16-bit standard m x k row-major layout (size [m][k]),
-// and writes the C matrix in 16-bit standard m x n row-major layout
-// (size [m][n])
-template <int KTilesPerWarp, KReductionType ReduceType = KReductionType::None>
-struct ALayout_RM {
-  static constexpr int32_t kMTileSize = 16;
-  static constexpr int32_t kNTileSize = 8;
-  static constexpr int32_t kKTileSize = 16;
-
-  static __device__ void load(
-      const void* A,
-      int32_t m,
-      int32_t k,
-      int32_t mTiles,
-      int32_t mTile,
-      int32_t kTiles,
-      int32_t kTileStart,
-      int32_t laneId,
-      bf16x2x4_u32 out[KTilesPerWarp]) {
-    auto mLane = mTile * kMTileSize + (laneId / 4);
-    auto kLane = kTileStart * kKTileSize + (laneId % 4) * 2;
-
-    // access
-    // [mTile * kMTileSize + (laneId / 4)]
-    // [kTileStart * kKTileSize + (laneId % 4) * 2]
-    auto aPtr = reinterpret_cast<const __nv_bfloat16*>(A) + mLane * k + kLane;
-
-    auto aPtrPlus8Rows = aPtr + 8 * k;
-
-    bool m0InBounds = mLane < m;
-    bool m1InBounds = (mLane + 8) < m;
-
-#pragma unroll
-    for (int i = 0; i < KTilesPerWarp; ++i) {
-      out[i].vals[0] = m0InBounds
-          ? *reinterpret_cast<const uint32_t*>(aPtr + i * kKTileSize)
-          : uint32_t(0);
-      out[i].vals[1] = m1InBounds
-          ? *reinterpret_cast<const uint32_t*>(aPtrPlus8Rows + i * kKTileSize)
-          : uint32_t(0);
-
-      out[i].vals[2] = m0InBounds
-          ? *reinterpret_cast<const uint32_t*>(aPtr + i * kKTileSize + 8)
-          : uint32_t(0);
-      out[i].vals[3] = m1InBounds ? *reinterpret_cast<const uint32_t*>(
-                                        aPtrPlus8Rows + i * kKTileSize + 8)
-                                  : uint32_t(0);
-    }
-  }
-
-  static __device__ void store(
-      void* C,
-      int32_t m,
-      int32_t n,
-      int32_t mOutTiles,
-      int32_t mTile,
-      int32_t nOutTiles,
-      int32_t nTile,
-      int32_t laneId,
-      const float4& out) {
-    static_assert(ReduceType == KReductionType::None, "");
-
-    if constexpr (ReduceType == KReductionType::None) {
-      // sum.x / sum.y are written at
-      // [laneId / 4], [(laneId % 4) * 2, (laneId % 4) * 2 + 1]
-      // sum.z / sum.w are written at
-      // [8 + (laneId / 4)], [(laneId % 4) * 2, (laneId % 4) * 2 + 1]
-      // i.e., same columns, different row.
-      int outRow = mTile * kMTileSize + (laneId / 4);
-      int outCol = nTile * kNTileSize + (laneId % 4) * 2;
-
-      // Pointer where sum.x / sum.y is written
-      auto cPtr = reinterpret_cast<__nv_bfloat16*>(C) + outRow * n + outCol;
-
-      auto v01 = __float22bfloat162_rn(float2{out.x, out.y});
-      auto v23 = __float22bfloat162_rn(float2{out.z, out.w});
-
-      if (outRow < m) {
-        *reinterpret_cast<__nv_bfloat162*>(cPtr) = v01;
-      }
-
-      // sum.z, sum.w at +8 rows from cPtr
-      if (outRow + 8 < m) {
-        *reinterpret_cast<__nv_bfloat162*>(cPtr + 8 * n) = v23;
-      }
-    }
-  }
-};
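-
-// Illustrative lane mapping for the load above (standard mma.m16n8k16 A
-// fragment): lane 5 of a warp reads the bf16 pairs at rows
-// {mTile * 16 + 1, mTile * 16 + 9} and columns {kTile * 16 + 2, kTile * 16 + 3}
-// and {kTile * 16 + 10, kTile * 16 + 11}, one 32-bit load per pair.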
-
-template <int KTilesPerWarp, int InnerKTiles, int QGroupSize>
-struct BLayout_TC_int4 {
-  static constexpr int32_t kInnerKTiles = InnerKTiles;
-  static constexpr int32_t kMTileSize = 16;
-  static constexpr int32_t kNTileSize = 8;
-  static constexpr int32_t kKTileSize = 16;
-
-  static __device__ void load(
-      // type uint32, size [n / 8][k / (InnerKTiles * 16)][32][InnerKTiles / 2]
-      // n / 8: n-tiles (n8)
-      // k / (InnerKTiles * 16): TC size per k-tile is 16 (m16n8k16)
-      // 32: value per warp lane
-      // (InnerKTiles / 2): B layout has 4 values per lane (16 bits) per k-tile.
-      // 2 k-tiles packed is a uint32 (hence InnerKTiles == 2 is our smallest value)
-      // 4 k-tiles packed is a uint32x2 (64 bits)
-      // 8 k-tiles packed is a uint32x4 (128 bits)
-      const void* __restrict__ B,
-      // size [k / qGroupSize][n][2]
-      // Contains the scale and zero point of each of the quantized int4 values
-      // within B
-      // v_reconstructed = (bf16(B_int4_val) * scale) + zero
-      const void* __restrict__ quantizationInfo,
-      int32_t n,
-      int32_t k,
-      int32_t nTiles,
-      int32_t nTile,
-      int32_t kTiles,
-      int32_t kTileStart,
-      int32_t laneId,
-      bf16x2x4_u32 out[KTilesPerWarp / InnerKTiles][InnerKTiles / 2]) {
-    // offset [nTile][kTileStart / InnerKTiles][laneId][0]
-    auto bPtr = reinterpret_cast<const int32_t*>(B) +
-        (((nTile * (kTiles / InnerKTiles) + (kTileStart / InnerKTiles)) *
-          kWarpSize) +
-         laneId) *
-            (InnerKTiles / 2);
-
-    int32_t b_int4[KTilesPerWarp / InnerKTiles][InnerKTiles / 2];
-
-#pragma unroll
-    for (int i = 0; i < KTilesPerWarp / InnerKTiles; ++i) {
-      auto bPtrCur = bPtr + i * kWarpSize * (InnerKTiles / 2);
-
-      if constexpr (InnerKTiles == 2) {
-        b_int4[i][0] = bPtrCur[0];
-      }
-
-      if constexpr (InnerKTiles == 4) {
-
-        // asm volatile("ld.global.cs.v2.u32 {%0, %1}, [%2];\n"
-        //              : "=r"(b_int4[i][0]), "=r"(b_int4[i][1])
-        //              : "l"(bPtrCur));
-
-        int2 load8 = reinterpret_cast<const int2*>(bPtrCur)[0];
-        b_int4[i][0] = load8.x;
-        b_int4[i][1] = load8.y;
-      }
-
-      if constexpr (InnerKTiles == 8) {
-
-        // asm volatile("ld.global.cs.v4.u32 {%0, %1, %2, %3}, [%4];\n"
-        //              : "=r"(b_int4[i][0]), "=r"(b_int4[i][1]), "=r"(b_int4[i][2]), "=r"(b_int4[i][3])
-        //              : "l"(bPtrCur));
-
-        int4 load16 = reinterpret_cast<const int4*>(bPtrCur)[0];
-        b_int4[i][0] = load16.x;
-        b_int4[i][1] = load16.y;
-        b_int4[i][2] = load16.z;
-        b_int4[i][3] = load16.w;
-      }
-    }
-
-    // Load needed info for dequantization
-
-    static_assert(isPowerOf2(QGroupSize), "");
-    static_assert(isEvenDivisor(QGroupSize, kKTileSize), "");
-    // smallest quantization group size is 32 (2 k-tiles are packed in an int32)
-    static_assert(QGroupSize >= kKTileSize * 2, "");
-    constexpr int kKTilesPerQGroup = (QGroupSize / kKTileSize);
-    // a q-group could be larger than what we are handling in a single warp
-    constexpr int kNumQGroups = (KTilesPerWarp / kKTilesPerQGroup) < 1
-        ? 1
-        : (KTilesPerWarp / kKTilesPerQGroup);
-
-    __nv_bfloat162 qScaleAndZero[kNumQGroups];
-    {
-      int32_t laneN = nTile * kNTileSize + (laneId / 4);
-      int32_t groupStart = (kTileStart * kKTileSize) / QGroupSize;
-
-      int32_t n = nTiles * kNTileSize;
-
-      // offset [qScale_kGroup][qScale_n][0]
-      auto qInfoPtr = reinterpret_cast<const __nv_bfloat16*>(quantizationInfo) +
-          (groupStart * n + laneN) * 2;
-
-#pragma unroll
-      for (int i = 0; i < kNumQGroups; ++i) {
-        qScaleAndZero[i] =
-            *reinterpret_cast<const __nv_bfloat162*>(qInfoPtr + i * n * 2);
-      }
-    }
-
-    //
-    // De-quantize int4 values to bf16. Values are dequantized as truly int4
-    // [-8, 7] range; dequant = (bf16(int4_value) * bf16_scale) + bf16_zero
-    //
-    {
-      // FIXME: does this negatively affect register counts, or will nvcc
-      // move this expansion (and data loads above) closer to the point of use?
-      __nv_bfloat162 qScale[kNumQGroups];
-      __nv_bfloat162 qZero[kNumQGroups];
-
-#pragma unroll
-      for (int i = 0; i < kNumQGroups; ++i) {
-        qScale[i] = __bfloat162bfloat162(qScaleAndZero[i].x);
-        qZero[i] = __bfloat162bfloat162(qScaleAndZero[i].y);
-      }
-
-#pragma unroll
-      for (int i = 0; i < KTilesPerWarp / InnerKTiles; ++i) {
-#pragma unroll
-        for (int j = 0; j < InnerKTiles / 2; ++j) {
-          bf16x2x4 v = convert_i4x8_to_bf16x2x4(b_int4[i][j]);
-
-          int curKTile = i * InnerKTiles + j * 2;
-          int curQGroup = (curKTile * kKTileSize) / QGroupSize;
-
-          // The dequantized values in `v` for a given lane have the same n
-          // dimension (the B tensor core layout has all values in the same
-          // thread along the same n) but different k dimension, but all are
-          // guaranteed to occur within the same quantization group, so we need
-          // only load a single scale + zero to cover what this lane has
-#pragma unroll
-          for (int k = 0; k < 4; ++k) {
-            v.vals[k] = __hfma2(v.vals[k], qScale[curQGroup], qZero[curQGroup]);
-          }
-
-          // type pun, the __nv_bfloat162 value in bf16x2x4 is a struct and can't
-          // be used as a 32-bit asm register argument for `mma`
-          static_assert(sizeof(bf16x2x4) == sizeof(out[0][0]), "");
-          std::memcpy(&out[i][j], &v, sizeof(bf16x2x4_u32));
-        }
-      }
-    }
-  }
-};
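-
-// Shape example for the packed operands consumed above (illustrative): with
-// n = 4096, k = 4096, InnerKTiles = 8 and QGroupSize = 128, B is an int32
-// tensor of size [512][32][32][4] (n / 8 n-tiles, k / (8 * 16) outer k-tiles,
-// 32 lanes, 8 / 2 words per lane) and quantizationInfo is a bf16 tensor of
-// size [32][4096][2] (k / QGroupSize groups, n, {scale, zero}).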
-
-template <
-    typename ALayout,
-    typename BLayout,
-    typename CLayout,
-    int Warps,
-    int KTilesPerWarp>
-__global__ __launch_bounds__(Warps * kWarpSize) void tinygemm_m16n8k16_chunk_kernel(
-    // Data for the A matrix, loaded as per ALayout
-    const void* __restrict__ A,
-
-    // Data for the B matrix, loaded as per BLayout
-    const void* __restrict__ B,
-
-    // Optional quantization data for dequantizing B, loaded as per BLayout
-    const void* __restrict__ B_quantizationInfo,
-
-    // Output data for the C matrix, stored as per CLayout
-    void* __restrict__ C,
-
-    // The size of the matrix multiplication
-    int32_t m,
-    int32_t n,
-    int32_t k,
-
-    // The size of the matrix multiplication, in multiples of our TC tile size
-    int32_t mTiles,
-    int32_t nTiles,
-    int32_t kTiles) {
-  constexpr int32_t kMTileSize = 16;
-  constexpr int32_t kNTileSize = 8;
-  constexpr int32_t kKTileSize = 16;
-
-  static_assert(
-      ALayout::kMTileSize == kMTileSize && ALayout::kNTileSize == kNTileSize &&
-          ALayout::kKTileSize == kKTileSize,
-      "");
-
-  static_assert(
-      BLayout::kMTileSize == kMTileSize && BLayout::kNTileSize == kNTileSize &&
-          BLayout::kKTileSize == kKTileSize,
-      "");
-
-  static_assert(
-      CLayout::kMTileSize == kMTileSize && CLayout::kNTileSize == kNTileSize &&
-          CLayout::kKTileSize == kKTileSize,
-      "");
-
-  constexpr int kInnerKTiles = BLayout::kInnerKTiles;
-
-  // 2/4/8 inner k-tiles correspond to 4, 8 and 16 byte innermost loads
-  static_assert(
-      kInnerKTiles == 2 || kInnerKTiles == 4 || kInnerKTiles == 8, "");
-
-  // We always process at least kInnerKTiles k-tiles back to back in a warp
-  static_assert(KTilesPerWarp >= kInnerKTiles && isEvenDivisor(KTilesPerWarp, kInnerKTiles), "");
-
-  auto warpId = threadIdx.y;
-  auto laneId = threadIdx.x;
-
-  int32_t mTile = blockIdx.z;
-  int32_t nTile = blockIdx.y;
-
-  float4 c{0.0f, 0.0f, 0.0f, 0.0f};
-
-  // Requirement: kTiles must be an even multiple of Warps * KTilesPerWarp
-  for (int32_t kTileBase = (blockIdx.x * Warps + warpId) * KTilesPerWarp;
-       kTileBase < kTiles;
-       kTileBase += Warps * KTilesPerWarp) {
-    //
-    // Load data from A
-    //
-    bf16x2x4_u32 a[KTilesPerWarp];
-    ALayout::load(A, m, k, mTiles, mTile, kTiles, kTileBase, laneId, a);
-
-    //
-    // Load data from B and de-quantize as needed
-    // Each k-tile is bf16x2x2
-    //
-    bf16x2x4_u32 b[KTilesPerWarp / kInnerKTiles][kInnerKTiles / 2];
-    BLayout::load(
-      B,
-      B_quantizationInfo,
-      n,
-      k,
-      nTiles,
-      nTile,
-      kTiles,
-      kTileBase,
-      laneId,
-      b);
-
-    //
-    // Now, perform the matrix multiplication
-    //
-
-    // We accumulate across k-tiles here
-#pragma unroll
-    for (int i = 0; i < KTilesPerWarp / kInnerKTiles; ++i) {
-      static_assert(isEvenDivisor(kInnerKTiles, 2) && kInnerKTiles >= 2, "");
-#pragma unroll
-      for (int j = 0; j < kInnerKTiles / 2; ++j) {
-        // We don't simply accumulate into `c`, as this would create too
-        // strong an execution dependency between successive mma instructions.
-        // Instead, we accumulate into `c` only periodically.
-        float4 cTmp[2];
-
-#pragma unroll
-        for (int k = 0; k < 2; ++k) {
-          cTmp[k] = float4{0.0f, 0.0f, 0.0f, 0.0f};
-        }
-
-#pragma unroll
-        for (int k = 0; k < 2; ++k) {
-          asm volatile(
-            "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
-            "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};"
-            : "=f"(cTmp[k].x), "=f"(cTmp[k].y), "=f"(cTmp[k].z), "=f"(cTmp[k].w)
-            : "r"(a[i * kInnerKTiles + j * 2 + k].vals[0]),
-              "r"(a[i * kInnerKTiles + j * 2 + k].vals[1]),
-              "r"(a[i * kInnerKTiles + j * 2 + k].vals[2]),
-              "r"(a[i * kInnerKTiles + j * 2 + k].vals[3]),
-              "r"(b[i][(j * 2 + k) / 2].vals[((j * 2 + k) % 2) * 2 + 0]),
-              "r"(b[i][(j * 2 + k) / 2].vals[((j * 2 + k) % 2) * 2 + 1]),
-              "f"(cTmp[k].x),
-              "f"(cTmp[k].y),
-              "f"(cTmp[k].z),
-              "f"(cTmp[k].w));
-        }
-
-#pragma unroll
-        for (int k = 0; k < 2; ++k) {
-          c.x += cTmp[k].x;
-          c.y += cTmp[k].y;
-          c.z += cTmp[k].z;
-          c.w += cTmp[k].w;
-        }
-      }
-    }
-  } // for all kTiles
-
-  //
-  // Reduce independent k-tiles (same m/n) across warps
-  //
-  __shared__ float4 smem_sum[Warps][kWarpSize];
-
-  smem_sum[warpId][laneId] = c;
-
-  __syncthreads();
-
-  if (warpId == 0) {
-    float4 sum_f32{0.0f, 0.0f, 0.0f, 0.0f};
-
-    // Reduce across the block in the first warp
-    for (int i = 0; i < Warps; ++i) {
-      float4 v = smem_sum[i][laneId];
-      sum_f32.x += v.x;
-      sum_f32.y += v.y;
-      sum_f32.z += v.z;
-      sum_f32.w += v.w;
-    }
-
-    // Write the reduced result (in the first warp) into the output
-    CLayout::store(
-        C,
-        m,
-        n,
-        mTiles,
-        mTile,
-        // n for C output becomes k for A input, so for m16n8k16,
-        // we need to halve the tiles
-        nTiles / 2,
-        nTile,
-        laneId,
-        sum_f32);
-  }
-}
-
-
-template <
-    typename ALayout,
-    typename BLayout,
-    typename CLayout,
-    int Warps,
-    int KTilesPerWarp>
-void launch_tinygemm_kernel(
-    const torch::Tensor& A,
-    const torch::Tensor& B,
-    const torch::Tensor* qScaleAndZeros, /* optional */
-    torch::Tensor& C_final,
-    int32_t mTiles,
-    int32_t nTiles,
-    int32_t kTiles,
-    int32_t m,
-    int32_t n,
-    int32_t k,
-    cudaStream_t stream) {
-  TORCH_CHECK(
-      kTiles >= (Warps * KTilesPerWarp) &&
-      isEvenDivisor(kTiles, Warps * KTilesPerWarp));
-
-  TORCH_CHECK(
-      KTilesPerWarp >= BLayout::kInnerKTiles && isPowerOf2(KTilesPerWarp));
-
-  // Each block loops over all of its k-tiles (the kTileBase loop in the
-  // kernel), so after the intra-block reduction across the k dimension a
-  // single output tile remains and no cross-block k-reduction is needed.
-  int32_t postKernelKTiles = 1;
-
-  auto grid = dim3(postKernelKTiles, nTiles, mTiles);
-  auto block = dim3(kWarpSize, Warps);
-
-  auto func =
-      tinygemm_m16n8k16_chunk_kernel<ALayout, BLayout, CLayout, Warps, KTilesPerWarp>;
-
-  func<<<grid, block, 0, stream>>>(
-      A.data_ptr(),
-      B.data_ptr(),
-      qScaleAndZeros ? qScaleAndZeros->data_ptr() : nullptr,
-      C_final.data_ptr(),
-      m,
-      n,
-      k,
-      mTiles,
-      nTiles,
-      kTiles);
-
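-  // The attribute query below is not used further; wrapping it in
-  // C10_CUDA_CHECK presumably surfaces errors such as an invalid or missing
-  // kernel image for the current architecture.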
-  cudaFuncAttributes funcAttr;
-  C10_CUDA_CHECK(cudaFuncGetAttributes(
-      &funcAttr,
-      func));
-}
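-
-// Launch-shape example for the launcher above (illustrative): for m = 1,
-// n = 4096, k = 4096 with Warps = 8 and KTilesPerWarp = 8, we get mTiles = 1,
-// nTiles = 512, kTiles = 256, grid = (1, 512, 1) and block = (32, 8), and
-// each block loops over its 256 k-tiles in chunks of 64 (8 per warp).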
-
-// FIXME: parallelize better, smem staging etc?
-template <int InnerKTiles>
-__global__ void matrix_to_m16n8k16_Bint4_layout(
-    // size [n][k]
-    const at::PackedTensorAccessor32<int32_t, 2, at::RestrictPtrTraits> in,
-    // size [ceil(n / 8)][ceil(k / (InnerKTiles * 16))][32][InnerKTiles / 2]
-    at::PackedTensorAccessor32<int32_t, 4, at::RestrictPtrTraits> out) {
-  // 8 int4 values are packed into each int32 word. Given that the m16n8k16 B
-  // layout provides 4 scalar values per lane per k-tile, the minimum number of
-  // innermost k-tiles that we can use is 2.
-  static_assert(InnerKTiles >= 2 && isPowerOf2(InnerKTiles), "");
-
-  constexpr int32_t kMTileSize = 16;
-  constexpr int32_t kNTileSize = 8;
-  constexpr int32_t kKTileSize = 16;
-
-  // gridDim.x corresponds to the number of k-tiles divided by InnerKTiles
-  auto kOuterTile = blockIdx.x;
-  auto nTile = blockIdx.y;
-  auto t = threadIdx.x;
-
-  // Two k-tiles are packed into an int32 at a time
-#pragma unroll
-  for (int innerKTile = 0; innerKTile < InnerKTiles; innerKTile += 2) {
-    // n dimension that this lane loads from
-    auto n0 = nTile * kNTileSize + (t / 4);
-
-    bool n0Valid = n0 < in.size(0);
-
-    int32_t ks[8];
-
-    auto kBase0 = (kOuterTile * InnerKTiles + innerKTile) * kKTileSize;
-    ks[0] = kBase0 + (t % 4) * 2;
-    ks[1] = ks[0] + 1;
-    ks[2] = ks[0] + 8;
-    ks[3] = ks[0] + 8 + 1;
-
-    auto kBase1 = kBase0 + kKTileSize;
-    ks[4] = kBase1 + (t % 4) * 2;
-    ks[5] = ks[4] + 1;
-    ks[6] = ks[4] + 8;
-    ks[7] = ks[4] + 8 + 1;
-
-    auto pIn = &in[n0][0];
-
-    uint32_t v[8];
-#pragma unroll
-    for (int i = 0; i < 8; ++i) {
-      v[i] = (n0Valid && ks[i] < in.size(1)) ? pIn[ks[i]] : uint32_t(0);
-    }
-
-    int32_t pack = (v[7] << 28) | (v[5] << 24) | (v[3] << 20) | (v[1] << 16) |
-        (v[6] << 12) | (v[4] << 8) | (v[2] << 4) | v[0];
-
-    // inner k-tiles pack two at a time
-    out[nTile][kOuterTile][t][innerKTile / 2] = pack;
-  }
-}
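-
-// Worked example of the packing above (for reference): if v[0..7] = 0..7,
-// then pack = 0x75316420, i.e. the low 16 bits hold v[0], v[2], v[4], v[6]
-// and the high 16 bits hold v[1], v[3], v[5], v[7]. convert_i4x8_to_bf16x2x4
-// then extracts one nibble from each half per step (mask 0x000f000f, shifting
-// right by 4 each iteration), recovering the pairs (v[0], v[1]), (v[2], v[3]),
-// (v[4], v[5]) and (v[6], v[7]) in order.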
-
-#endif
-
-
-torch::Tensor _weight_int4pack_mm_cuda(
-    const torch::Tensor& A,
-    const torch::Tensor& B,
-    int64_t qGroupSize,
-    const torch::Tensor& qScaleAndZeros) {
-  c10::cuda::CUDAGuard g(A.device());
-  auto stream = at::cuda::getCurrentCUDAStream();
-
-  TORCH_CHECK(
-      A.device() == B.device() && A.device() == qScaleAndZeros.device());
-
-  constexpr int32_t kMTileSize = 16;
-  constexpr int32_t kNTileSize = 8;
-  constexpr int32_t kKTileSize = 16;
-
-  // row major layout
-  auto m = A.size(0);
-  auto mTiles = divUp(m, kMTileSize);
-
-  // tensor core layout
-  auto nTiles = B.size(0);
-  auto n = nTiles * kNTileSize;
-
-  // row major layout
-  auto k = A.size(1);
-  auto kTiles = divUp(k, kKTileSize);
-
-  // The number of inner k-tiles is the innermost dimension of B times 2:
-  // 2 k-tiles (4 values per lane per tile, 8 values total) quantized to int4
-  // are packed into 1 int32 for int4 B
-  auto B_innerKTiles = B.size(3) * 2;
-  TORCH_CHECK(B_innerKTiles == 2 || B_innerKTiles == 4 || B_innerKTiles == 8);
-
-  // A is standard row major
-  TORCH_CHECK(A.dtype() == torch::kBFloat16);
-  TORCH_CHECK(A.is_contiguous());
-  TORCH_CHECK(A.dim() == 2);
-
-  // B has B_innerKTiles k-tiles in the innermost dimension
-  TORCH_CHECK(B.dtype() == torch::kInt32);
-  TORCH_CHECK(B.is_contiguous());
-  TORCH_CHECK(B.dim() == 4);
-  TORCH_CHECK(B.size(1) == k / (B_innerKTiles * kKTileSize));
-  TORCH_CHECK(B.size(2) == kWarpSize);
-
-  // Validate the scale and zero point tensor for dequantization
-  TORCH_CHECK(qScaleAndZeros.dim() == 3);
-  auto numQGroups = qScaleAndZeros.size(0);
-  TORCH_CHECK(k / numQGroups == qGroupSize);
-  TORCH_CHECK(qScaleAndZeros.size(1) == n);
-  TORCH_CHECK(qScaleAndZeros.size(2) == 2);
-
-  // Output is a standard row-major matrix
-  auto C_final = torch::empty(
-      {m, n},
-      torch::TensorOptions().dtype(at::kBFloat16).device(A.device()));
-
-#if (defined(CUDA_VERSION) && CUDA_VERSION >= 12000) && (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 800))
-#define RUN_GEMM(WARPS, K_TILES_PER_WARP, Q_GROUP_SIZE, REDUCE_TYPE)          \
-  do {                                                                        \
-    using ACLayout = ALayout_RM<K_TILES_PER_WARP, REDUCE_TYPE>;               \
-                                                                              \
-    TORCH_CHECK(                                                              \
-        K_TILES_PER_WARP >= B_innerKTiles &&                                  \
-        isEvenDivisor(K_TILES_PER_WARP, B_innerKTiles));                      \
-                                                                              \
-    switch (B_innerKTiles) {                                                  \
-      case 2:                                                                 \
-        if constexpr (K_TILES_PER_WARP >= 2) {                                \
-          using BLayout = BLayout_TC_int4<K_TILES_PER_WARP, 2, Q_GROUP_SIZE>; \
-          launch_tinygemm_kernel<                                             \
-              ACLayout,                                                       \
-              BLayout,                                                        \
-              ACLayout,                                                       \
-              WARPS,                                                          \
-              K_TILES_PER_WARP>(                                              \
-              A,                                                              \
-              B,                                                              \
-              &qScaleAndZeros,                                                \
-              C_final,                                                        \
-              mTiles,                                                         \
-              nTiles,                                                         \
-              kTiles,                                                         \
-              m,                                                              \
-              n,                                                              \
-              k,                                                              \
-              stream);                                                        \
-        }                                                                     \
-        break;                                                                \
-      case 4:                                                                 \
-        if constexpr (K_TILES_PER_WARP >= 4) {                                \
-          using BLayout = BLayout_TC_int4<K_TILES_PER_WARP, 4, Q_GROUP_SIZE>; \
-          launch_tinygemm_kernel<                                             \
-              ACLayout,                                                       \
-              BLayout,                                                        \
-              ACLayout,                                                       \
-              WARPS,                                                          \
-              K_TILES_PER_WARP>(                                              \
-              A,                                                              \
-              B,                                                              \
-              &qScaleAndZeros,                                                \
-              C_final,                                                        \
-              mTiles,                                                         \
-              nTiles,                                                         \
-              kTiles,                                                         \
-              m,                                                              \
-              n,                                                              \
-              k,                                                              \
-              stream);                                                        \
-        }                                                                     \
-        break;                                                                \
-      case 8:                                                                 \
-        if constexpr (K_TILES_PER_WARP >= 8) {                                \
-          using BLayout = BLayout_TC_int4<K_TILES_PER_WARP, 8, Q_GROUP_SIZE>; \
-          launch_tinygemm_kernel<                                             \
-              ACLayout,                                                       \
-              BLayout,                                                        \
-              ACLayout,                                                       \
-              WARPS,                                                          \
-              K_TILES_PER_WARP>(                                              \
-              A,                                                              \
-              B,                                                              \
-              &qScaleAndZeros,                                                \
-              C_final,                                                        \
-              mTiles,                                                         \
-              nTiles,                                                         \
-              kTiles,                                                         \
-              m,                                                              \
-              n,                                                              \
-              k,                                                              \
-              stream);                                                        \
-        }                                                                     \
-        break;                                                                \
-      default:                                                                \
-        break;                                                                \
-    }                                                                         \
-  } while (false)
-
-#define HANDLE_Q_GROUP(WARPS, K_TILES_PER_WARP, REDUCE_TYPE) \
-  do {                                                       \
-    switch (qGroupSize) {                                    \
-      case 32:                                               \
-        RUN_GEMM(WARPS, K_TILES_PER_WARP, 32, REDUCE_TYPE);  \
-        break;                                               \
-      case 64:                                               \
-        RUN_GEMM(WARPS, K_TILES_PER_WARP, 64, REDUCE_TYPE);  \
-        break;                                               \
-      case 128:                                              \
-        RUN_GEMM(WARPS, K_TILES_PER_WARP, 128, REDUCE_TYPE); \
-        break;                                               \
-      case 256:                                              \
-        RUN_GEMM(WARPS, K_TILES_PER_WARP, 256, REDUCE_TYPE); \
-        break;                                               \
-    }                                                        \
-  } while (false)
-
-  // These are the only versions handled at the moment
-  TORCH_CHECK(k == 32 || k == 64 || isEvenDivisor(k, 1024));
-  TORCH_CHECK(
-      qGroupSize == 32 || qGroupSize == 64 || qGroupSize == 128 ||
-      qGroupSize == 256);
-
-  if (k == 32) {
-    HANDLE_Q_GROUP(1, 2, KReductionType::None);
-  } else if (k == 64) {
-    HANDLE_Q_GROUP(1, 4, KReductionType::None);
-  } else {
-    // Handle 1024 k values (8 warps * 8 k-tiles * kKTileSize) at a time
-    HANDLE_Q_GROUP(8, 8, KReductionType::None);
-  }
-
-#undef RUN_GEMM
-
-  return C_final;
-#endif
-  TORCH_CHECK(false, "_weight_int4pack_mm_cuda is not available for build.")
-  return C_final;
-}
-
-// input is [n][k] (int32 dtype)
-// output is [n / 8][k / (innerKTiles * 16)][32][innerKTiles / 2]
-torch::Tensor _convert_weight_to_int4pack_cuda(
-    const torch::Tensor& in,
-    int64_t innerKTiles) {
-  c10::cuda::CUDAGuard g(in.device());
-  auto stream = at::cuda::getCurrentCUDAStream();
-
-  TORCH_CHECK(in.dim() == 2);
-  TORCH_CHECK(in.dtype() == torch::kInt32);
-  TORCH_CHECK(in.is_contiguous());
-
-  // At least 2 k-tiles need to be packed back to back in the innermost
-  // dimension, as the m16n8k16 tensor core tile presents 4 scalar values per
-  // lane for the B matrix, while the minimum word size for the packed format
-  // is 4 bytes (int32). 4 inner k-tiles = 8 byte load, 8 inner k-tiles =
-  // 16 byte load, which is the maximum vectorized load/store size.
-  TORCH_CHECK(innerKTiles == 2 || innerKTiles == 4 || innerKTiles == 8);
-
-  constexpr int32_t kMTileSize = 16;
-  constexpr int32_t kNTileSize = 8;
-  constexpr int32_t kKTileSize = 16;
-
-  auto nTiles = divUp(in.size(0), kNTileSize);
-
-  // k-tiles are packed back to back in the innermost dimension in order to
-  // allow for 4/8/16 byte loads
-  TORCH_CHECK(isEvenDivisor(in.size(1), innerKTiles * kKTileSize));
-  // kSuperTiles is the number of groups of innerKTiles k-tiles, i.e.
-  // k / (innerKTiles * kKTileSize)
-  auto kSuperTiles = divUp(in.size(1), innerKTiles * kKTileSize);
-
-  // each block handles `innerKTiles` k-tiles.
-  // 2 k-tiles are a single int32
-  auto out = torch::empty(
-      {nTiles, kSuperTiles, 32, innerKTiles / 2},
-      torch::TensorOptions().dtype(torch::kInt32).device(in.device()));
-
-#if (defined(CUDA_VERSION) && CUDA_VERSION >= 12000) && (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 800))
-  dim3 grid(kSuperTiles, nTiles);
-
-  if (innerKTiles == 2) {
-    matrix_to_m16n8k16_Bint4_layout<2><<<grid, kWarpSize, 0, stream>>>(
-        in.packed_accessor32<int32_t, 2, at::RestrictPtrTraits>(),
-        out.packed_accessor32<int32_t, 4, at::RestrictPtrTraits>());
-  } else if (innerKTiles == 4) {
-    matrix_to_m16n8k16_Bint4_layout<4><<<grid, kWarpSize, 0, stream>>>(
-        in.packed_accessor32<int32_t, 2, at::RestrictPtrTraits>(),
-        out.packed_accessor32<int32_t, 4, at::RestrictPtrTraits>());
-  } else if (innerKTiles == 8) {
-    matrix_to_m16n8k16_Bint4_layout<8><<<grid, kWarpSize, 0, stream>>>(
-        in.packed_accessor32<int32_t, 2, at::RestrictPtrTraits>(),
-        out.packed_accessor32<int32_t, 4, at::RestrictPtrTraits>());
-  }
-
-  return out;
-#endif
-  TORCH_CHECK(false, "_convert_weight_to_int4pack_cuda is not available for build.")
-  return out;
-}
-
-
-} // namespace at::native
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 36a2f4e..5ffad03 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -4019,14 +4019,6 @@
   dispatch:
     CUDA: _int_mm_out_cuda
 
-- func: _convert_weight_to_int4pack(Tensor self, int innerKTiles) -> Tensor
-  dispatch:
-    CUDA: _convert_weight_to_int4pack_cuda
-
-- func: _weight_int4pack_mm(Tensor self, Tensor mat2, int qGroupSize, Tensor qScaleAndZeros) -> Tensor
-  dispatch:
-    CUDA: _weight_int4pack_mm_cuda
-
 - func: _sparse_mm(Tensor sparse, Tensor dense) -> Tensor
   python_module: sparse
 
diff --git a/test/expect/HasDecompTest.test_has_decomposition.expect b/test/expect/HasDecompTest.test_has_decomposition.expect
index 8e84155..d460d19 100644
--- a/test/expect/HasDecompTest.test_has_decomposition.expect
+++ b/test/expect/HasDecompTest.test_has_decomposition.expect
@@ -54,7 +54,6 @@
 aten::_convert_indices_from_coo_to_csr.out
 aten::_convert_indices_from_csr_to_coo
 aten::_convert_indices_from_csr_to_coo.out
-aten::_convert_weight_to_int4pack
 aten::_convolution
 aten::_convolution.out
 aten::_copy_from
@@ -588,7 +587,6 @@
 aten::_values
 aten::_values_copy
 aten::_values_copy.out
-aten::_weight_int4pack_mm
 aten::_weight_norm_interface
 aten::_weight_norm_interface.out
 aten::_weight_norm_interface_backward
diff --git a/test/test_linalg.py b/test/test_linalg.py
index 17217d9..a1e119a 100644
--- a/test/test_linalg.py
+++ b/test/test_linalg.py
@@ -29,7 +29,7 @@
     all_types, all_types_and_complex_and, floating_and_complex_types, integral_types,
     floating_and_complex_types_and, floating_types_and, complex_types,
 )
-from torch.testing._internal.common_cuda import SM53OrLater, SM80OrLater, tf32_on_and_off, _get_magma_version, \
+from torch.testing._internal.common_cuda import SM53OrLater, tf32_on_and_off, _get_magma_version, \
     _get_torch_cuda_version
 from torch.distributions.binomial import Binomial
 import torch.backends.opt_einsum as opt_einsum
@@ -5771,123 +5771,6 @@
                                r"Expected result.size\(0\) to be 17 but got 16",
                                lambda: torch._int_mm(genf_int(17, 8), genf_int(8, 32), out=genf_int(16, 31).int()))
 
-    def _group_quantize_tensor(self, w, n_bit=4, q_group_size=16):
-        assert w.dim() == 2
-        w = w.transpose(0, 1).contiguous()
-        assert q_group_size > 1
-        assert w.shape[-1] % q_group_size == 0
-
-        to_quant = w.reshape(-1, q_group_size)
-        assert torch.isnan(to_quant).sum() == 0
-
-        max_val = to_quant.amax(dim=1, keepdim=True)
-        min_val = to_quant.amin(dim=1, keepdim=True)
-        max_int = 2 ** n_bit - 1
-        min_int = 0
-        scales = (max_val - min_val).clamp(min=1e-6) / max_int
-        assert torch.isnan(scales).sum() == 0
-
-        zeros = min_val + scales * (2 ** (n_bit - 1))
-        assert torch.isnan(zeros).sum() == 0
-
-        out = to_quant.sub(min_val).div(scales).round().clamp_(min_int, max_int)
-        assert torch.isnan(out).sum() == 0
-
-        out = out.to(dtype=torch.int32).reshape(w.shape)
-
-        # Scales and zeros for the same q-group should be contiguous, so we can
-        # load as a 32-bit word
-        scales = scales.view(w.shape[0], -1)
-        zeros = zeros.view(w.shape[0], -1)
-        scales_and_zeros = (
-            torch.cat(
-                [
-                    scales.reshape(scales.size(0), scales.size(1), 1),
-                    zeros.reshape(zeros.size(0), zeros.size(1), 1),
-                ],
-                2,
-            ).transpose(0, 1).contiguous()
-        )
-
-        return out, scales_and_zeros
-
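-    # Worked example of the quantization above (illustrative): for a q-group
-    # with min_val = -1.0 and max_val = 2.0, scales = 3.0 / 15 = 0.2 and
-    # zeros = -1.0 + 0.2 * 8 = 0.6. A weight of 0.6 quantizes to
-    # round((0.6 - (-1.0)) / 0.2) = 8, and the kernel reconstructs it as
-    # (8 - 8) * 0.2 + 0.6 = 0.6, treating the stored value as signed int4
-    # with the group minimum folded into `zeros`.
-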
-    @unittest.skipIf(IS_WINDOWS, "Skipped on Windows!")
-    @unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "cublas runtime error")
-    @unittest.skipIf(not SM80OrLater, "need sm_80")
-    @onlyCUDA
-    @parametrize("m", [32, 64])
-    @parametrize("k", [32, 64])
-    @parametrize("n", [48, 64])
-    def test__int4_mm(self, device, m, k, n):
-        if TEST_WITH_ROCM:
-            self.skipTest("_int4_mm not compiled for ROCM")
-
-        q_group = 32
-        inner_k_tiles = 2
-
-        torch.manual_seed(1)
-        a = torch.rand((m, k), dtype=torch.bfloat16, device=device)
-        b = torch.rand((k, n), dtype=torch.bfloat16, device=device)
-
-        def convert_weight_to_int4pack(b):
-            b_int32, b_scales_and_zeros = self._group_quantize_tensor(
-                b, n_bit=4, q_group_size=q_group
-            )
-            b_int4pack = torch._convert_weight_to_int4pack(
-                b_int32, inner_k_tiles
-            )
-
-            return b_int4pack, b_scales_and_zeros
-
-        def weight_int4pack_mm(a, b_int4pack, b_scales_and_zeros):
-            return torch._weight_int4pack_mm(
-                a, b_int4pack, q_group, b_scales_and_zeros
-            )
-
-        b_int4pack, b_scales_and_zeros = convert_weight_to_int4pack(b)
-        res = weight_int4pack_mm(a, b_int4pack, b_scales_and_zeros)
-        ref = torch.mm(a, b)
-
-        mean_err = ((res - ref).abs() / ref).mean()
-        self.assertTrue(mean_err < 0.05)
-
-    @unittest.skipIf(IS_WINDOWS, "Skipped on Windows!")
-    @unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "cublas runtime error")
-    @unittest.skipIf(not SM80OrLater, "need sm_80")
-    @onlyCUDA
-    @parametrize("m", [32, 64])
-    @parametrize("k", [32, 64])
-    @parametrize("n", [48, 64])
-    def test_compile_int4_mm(self, device, m, k, n):
-        if TEST_WITH_ROCM:
-            self.skipTest("_int4_mm not compiled for ROCM")
-
-        q_group = 32
-        inner_k_tiles = 2
-
-        torch.manual_seed(1)
-        a = torch.rand((m, k), dtype=torch.bfloat16, device=device)
-        b = torch.rand((k, n), dtype=torch.bfloat16, device=device)
-
-        b_int32, b_scales_and_zeros = self._group_quantize_tensor(
-            b, n_bit=4, q_group_size=q_group
-        )
-
-        @torch.compile
-        def int4_mm(a, b_int32, b_scales_and_zeros):
-            b_int4pack = torch._convert_weight_to_int4pack(
-                b_int32, inner_k_tiles
-            )
-            return torch._weight_int4pack_mm(
-                a, b_int4pack, q_group, b_scales_and_zeros
-            )
-
-        res = int4_mm(a, b_int32, b_scales_and_zeros)
-        ref = torch.mm(a, b)
-
-        mean_err = ((res - ref).abs() / ref).mean()
-        self.assertTrue(mean_err < 0.05)
-
     @slowTest
     @onlyNativeDeviceTypes
     # bfloat16 doesn't have sufficient precision to pass this test
diff --git a/torch/_meta_registrations.py b/torch/_meta_registrations.py
index 3afcc2e..98097f5 100644
--- a/torch/_meta_registrations.py
+++ b/torch/_meta_registrations.py
@@ -3216,41 +3216,6 @@
     return a.new_empty((a.size(0), b.size(1)), dtype=torch.int32)
 
 
-@register_meta([aten._convert_weight_to_int4pack])
-def meta__convert_weight_to_int4pack(w, inner_k_tiles):
-    torch._check(w.dim() == 2, lambda: "w must be a 2D tensor")
-    torch._check(
-        w.dtype is torch.int32,
-        lambda: f"expected w to be int32, got {w.dtype}",
-    )
-    n = w.size(0)
-    k = w.size(1)
-    return w.new_empty(
-        (
-            n // 8,
-            k // (inner_k_tiles * 16),
-            32,
-            inner_k_tiles // 2,
-        ),
-        dtype=torch.int32,
-    )
-
-
-@register_meta([aten._weight_int4pack_mm])
-def meta__weight_int4pack_mm(x, w, q_group_size, q_scale_and_zeros):
-    torch._check(x.dim() == 2, lambda: "x must be a 2D tensor")
-    torch._check(w.dim() == 4, lambda: "w must be a 4D tensor")
-    torch._check(
-        x.dtype is torch.bfloat16,
-        lambda: f"expected x to be bf16, got {x.dtype}",
-    )
-    torch._check(
-        w.dtype is torch.int32,
-        lambda: f"expected w to be int32, got {w.dtype}",
-    )
-    return x.new_empty(x.size(0), w.size(0) * 8, dtype=x.dtype)
-
-
 @register_meta(aten._cdist_forward.default)
 def meta_cdist_forward(x1, x2, p, compute_mode):
     torch._check(