Revert "Quant: add weight int4pack mm kernel (#110914)"
This reverts commit 9980876cab9dcedce7d7dd1c8a2e168b548eaa36.
Reverted https://github.com/pytorch/pytorch/pull/110914 on behalf of https://github.com/jeanschmidt due to Breaking internal builds ([comment](https://github.com/pytorch/pytorch/pull/110914#issuecomment-1765302621))
diff --git a/aten/src/ATen/native/cuda/int4mm.cu b/aten/src/ATen/native/cuda/int4mm.cu
deleted file mode 100644
index 8f64e14..0000000
--- a/aten/src/ATen/native/cuda/int4mm.cu
+++ /dev/null
@@ -1,1020 +0,0 @@
-#if (defined(CUDA_VERSION) && CUDA_VERSION >= 12000) && (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 800))
-#include <cuda_bf16.h>
-#include <cuda_fp16.h>
-#include <cuda_runtime.h>
-#include <mma.h>
-#endif
-#include <ATen/ATen.h>
-#include <ATen/core/Tensor.h>
-#include <ATen/cuda/CUDAContext.h>
-#include <ATen/DeviceGuard.h>
-#include <c10/cuda/CUDAGuard.h>
-#include <torch/types.h>
-
-
-namespace at::native {
-
-template <typename U, typename V>
-constexpr __host__ __device__ auto divDown(U a, V b) -> decltype(a + b) {
- static_assert(std::is_integral<U>::value && std::is_integral<V>::value, "");
- return (a / b);
-}
-
-template <typename U, typename V>
-constexpr __host__ __device__ auto divUp(U a, V b) -> decltype(a + b) {
- static_assert(std::is_integral<U>::value && std::is_integral<V>::value, "");
- return (a + b - 1) / b;
-}
-
-template <typename U, typename V>
-constexpr __host__ __device__ auto roundDown(U a, V b) -> decltype(a + b) {
- static_assert(std::is_integral<U>::value && std::is_integral<V>::value, "");
- return divDown(a, b) * b;
-}
-
-template <typename U, typename V>
-constexpr __host__ __device__ auto roundUp(U a, V b) -> decltype(a + b) {
- static_assert(std::is_integral<U>::value && std::is_integral<V>::value, "");
- return divUp(a, b) * b;
-}
-
-template <typename U, typename V>
-constexpr __host__ __device__ bool isEvenDivisor(U a, V b) {
- static_assert(std::is_integral<U>::value && std::is_integral<V>::value, "");
- return (a % V(b) == 0) && ((a / V(b)) >= 1);
-}
-
-template <class T>
-constexpr __host__ __device__ T pow(T n, int power) {
- return (power > 0 ? n * pow(n, power - 1) : 1);
-}
-
-template <class T>
-constexpr __host__ __device__ T pow2(int power) {
- return pow(2, power);
-}
-
-static_assert(pow2<int>(8) == 256, "pow2");
-
-template <typename T>
-constexpr __host__ __device__ int log2(T n, int p = 0) {
- return (n <= 1) ? p : log2(n / 2, p + 1);
-}
-
-static_assert(log2(2) == 1, "log2");
-static_assert(log2(3) == 1, "log2");
-static_assert(log2(4) == 2, "log2");
-
-template <typename T>
-constexpr __host__ __device__ bool isPowerOf2(T v) {
- static_assert(std::is_integral<T>::value, "");
- return (v && !(v & (v - 1)));
-}
-
-static_assert(isPowerOf2(2048), "isPowerOf2");
-static_assert(!isPowerOf2(3333), "isPowerOf2");
-
-template <typename T>
-constexpr __host__ __device__ T nextHighestPowerOf2(T v) {
- static_assert(std::is_integral<T>::value, "");
- return (isPowerOf2(v) ? (T)2 * v : ((T)1 << (log2(v) + 1)));
-}
-
-static_assert(nextHighestPowerOf2(1) == 2, "nextHighestPowerOf2");
-static_assert(nextHighestPowerOf2(2) == 4, "nextHighestPowerOf2");
-static_assert(nextHighestPowerOf2(3) == 4, "nextHighestPowerOf2");
-static_assert(nextHighestPowerOf2(4) == 8, "nextHighestPowerOf2");
-
-static_assert(nextHighestPowerOf2(15) == 16, "nextHighestPowerOf2");
-static_assert(nextHighestPowerOf2(16) == 32, "nextHighestPowerOf2");
-static_assert(nextHighestPowerOf2(17) == 32, "nextHighestPowerOf2");
-
-static_assert(
- nextHighestPowerOf2(1536000000u) == 2147483648u,
- "nextHighestPowerOf2");
-static_assert(
- nextHighestPowerOf2((size_t)2147483648ULL) == (size_t)4294967296ULL,
- "nextHighestPowerOf2");
-
-template <typename T>
-constexpr __host__ __device__ T nextLowestPowerOf2(T v) {
- static_assert(std::is_integral<T>::value, "");
- return (isPowerOf2(v) ? v / (T)2 : ((T)1 << (log2(v))));
-}
-
-static_assert(nextLowestPowerOf2(1) == 0, "nextLowestPowerOf2");
-static_assert(nextLowestPowerOf2(2) == 1, "nextLowestPowerOf2");
-static_assert(nextLowestPowerOf2(3) == 2, "nextLowestPowerOf2");
-static_assert(nextLowestPowerOf2(4) == 2, "nextLowestPowerOf2");
-
-static_assert(nextLowestPowerOf2(15) == 8, "nextLowestPowerOf2");
-static_assert(nextLowestPowerOf2(16) == 8, "nextLowestPowerOf2");
-static_assert(nextLowestPowerOf2(17) == 16, "nextLowestPowerOf2");
-
-inline __host__ __device__ bool isPointerAligned(const void* p, int align) {
- return reinterpret_cast<uintptr_t>(p) % align == 0;
-}
-
-// Returns the increment needed to aligned the pointer to the next highest
-// aligned address
-template <int Align>
-inline __host__ __device__ uint32_t getAlignmentRoundUp(const void* p) {
- static_assert(isPowerOf2(Align), "");
- uint32_t diff = uint32_t(uintptr_t(p) & uintptr_t(Align - 1));
- return diff == 0 ? 0 : uint32_t(Align) - diff;
-}
-
-constexpr int32_t kWarpSize = 32;
-
-#if (defined(CUDA_VERSION) && CUDA_VERSION >= 12000) && (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 800))
-// f16 vector types
-struct __align__(2) f16x1 {
- __half vals[1];
-};
-
-struct __align__(4) f16x2 {
- __half vals[2];
-};
-
-struct __align__(8) f16x4 {
- __half vals[4];
-};
-
-struct __align__(16) f16x8 {
- __half vals[8];
-};
-
-// bf16 vector types
-struct __align__(2) bf16x1 {
- __nv_bfloat16 vals[1];
-};
-
-struct __align__(4) bf16x2 {
- __nv_bfloat16 vals[2];
-};
-
-struct __align__(8) bf16x4 {
- __nv_bfloat16 vals[4];
-};
-
-struct __align__(16) bf16x8 {
- __nv_bfloat16 vals[8];
-};
-
-// bf162 vector types
-struct __align__(4) bf16x2x1 {
- __nv_bfloat162 vals[1];
-};
-
-struct __align__(8) bf16x2x2 {
- __nv_bfloat162 vals[2];
-};
-
-struct __align__(16) bf16x2x4 {
- __nv_bfloat162 vals[4];
-};
-
-struct __align__(16) bf16x2x4_u32 {
- uint32_t vals[4];
-};
-
-struct __align__(8) bf16x2x2_u32 {
- uint32_t vals[2];
-};
-
-struct __align__(4) bf16x2x1_u32 {
- uint32_t vals[1];
-};
-
-template <typename T, int N>
-struct __align__(sizeof(T) * N) VectorType {
- T vals[N];
-};
-
-// from
-// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
-inline __device__ bf16x2x4 convert_i4x8_to_bf16x2x4(uint32_t source) {
- bf16x2x4 result;
- constexpr int kElements = 8;
-
- uint32_t* h = reinterpret_cast<uint32_t*>(&result);
- uint32_t const source_i4s = source;
-
- // First, we extract the i4s and construct an intermediate fp16 number.
- static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa;
- static constexpr uint32_t MASK = 0x000f000f;
- static constexpr uint32_t I4s_TO_BF16s_MAGIC_NUM = 0x43004300;
-
- // We don't have enough mantissa to remove as much shift overhead as FP16, so
- // we must loop. No shift needed for first item.
- uint32_t i4s = source_i4s;
- asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
- : "=r"(h[0])
- : "r"(i4s), "n"(MASK), "n"(I4s_TO_BF16s_MAGIC_NUM), "n"(immLut));
-#pragma unroll
- for (int ii = 1; ii < kElements / 2; ++ii) {
- i4s >>= 4; // or is it 8?
- // (i4s & 0x000f000f) | 0x43004300
- asm volatile(
- "lop3.b32 %0, %1, %2, %3, %4;\n"
- : "=r"(h[ii])
- : "r"(i4s), "n"(MASK), "n"(I4s_TO_BF16s_MAGIC_NUM), "n"(immLut));
- }
-
- // This is the BF16 {-136, -136} represented as an integer.
- static constexpr uint32_t BF16_BIAS = 0xC308C308;
- static constexpr uint32_t BF16_ONE = 0x3F803F80;
-
-// Finally, we construct the output numbers.
-#pragma unroll
- for (int ii = 0; ii < kElements / 2; ++ii) {
- // Since this section is for Ampere+, we use bf16 fma to do the bias
- // subtraction
- asm("fma.rn.bf16x2 %0, %1, %2, %3;\n"
- : "=r"(h[ii])
- : "r"(h[ii]), "r"(BF16_ONE), "r"(BF16_BIAS));
- }
-
- return result;
-}
-
-
-
-enum class KReductionType {
- // No k-reduction is needed between blocks as the number of k-tiles processed
- // per block are exact and we can directly write the output
- None,
-};
-
-// Loads the A matrix in 16-bit standard m x k row major layout, and writes
-// the C matrix in 16-bit standard m x n row major layout:
-//
-// size [m][k]
-template <int KTilesPerWarp, KReductionType ReduceType = KReductionType::None>
-struct ALayout_RM {
- static constexpr int32_t kMTileSize = 16;
- static constexpr int32_t kNTileSize = 8;
- static constexpr int32_t kKTileSize = 16;
-
- static __device__ void load(
- const void* A,
- int32_t m,
- int32_t k,
- int32_t mTiles,
- int32_t mTile,
- int32_t kTiles,
- int32_t kTileStart,
- int32_t laneId,
- bf16x2x4_u32 out[KTilesPerWarp]) {
- auto mLane = mTile * kMTileSize + (laneId / 4);
- auto kLane = kTileStart * kKTileSize + (laneId % 4) * 2;
-
- // access
- // [mTile * kMTileSize + (laneId / 4)]
- // [kTileStart * kKTileSize + (laneId % 4) * 2]
- auto aPtr = reinterpret_cast<const __nv_bfloat16*>(A) + mLane * k + kLane;
-
- auto aPtrPlus8Rows = aPtr + 8 * k;
-
- bool m0InBounds = mLane < m;
- bool m1InBounds = (mLane + 8) < m;
-
-#pragma unroll
- for (int i = 0; i < KTilesPerWarp; ++i) {
- out[i].vals[0] = m0InBounds
- ? *reinterpret_cast<const uint32_t*>(aPtr + i * kKTileSize)
- : uint32_t(0);
- out[i].vals[1] = m1InBounds
- ? *reinterpret_cast<const uint32_t*>(aPtrPlus8Rows + i * kKTileSize)
- : uint32_t(0);
-
- out[i].vals[2] = m0InBounds
- ? *reinterpret_cast<const uint32_t*>(aPtr + i * kKTileSize + 8)
- : uint32_t(0);
- out[i].vals[3] = m1InBounds ? *reinterpret_cast<const uint32_t*>(
- aPtrPlus8Rows + i * kKTileSize + 8)
- : uint32_t(0);
- }
- }
-
- static __device__ void store(
- void* C,
- int32_t m,
- int32_t n,
- int32_t mOutTiles,
- int32_t mTile,
- int32_t nOutTiles,
- int32_t nTile,
- int32_t laneId,
- const float4& out) {
- static_assert(ReduceType == KReductionType::None, "");
-
- if constexpr (ReduceType == KReductionType::None) {
- // sum.x / sum.y are written at
- // [laneId / 4], [(laneId % 4) * 2, (laneId % 4) * 2 + 1]
- // sum.z / sum.w are written at
- // [8 + (laneId / 4)], [(laneId % 4) * 2, (laneId % 4) * 2 + 1]
- // i.e., same columns, different row.
- int outRow = mTile * kMTileSize + (laneId / 4);
- int outCol = nTile * kNTileSize + (laneId % 4) * 2;
-
- // Pointer where sum.x / sum.y is written
- auto cPtr = reinterpret_cast<__nv_bfloat16*>(C) + outRow * n + outCol;
-
- auto v01 = __float22bfloat162_rn(float2{out.x, out.y});
- auto v23 = __float22bfloat162_rn(float2{out.z, out.w});
-
- if (outRow < m) {
- *reinterpret_cast<__nv_bfloat162*>(cPtr) = v01;
- }
-
- // sum.z, sum.w at +8 rows from cPtr
- if (outRow + 8 < m) {
- *reinterpret_cast<__nv_bfloat162*>(cPtr + 8 * n) = v23;
- }
- }
- }
-};
-
-template <int KTilesPerWarp, int InnerKTiles, int QGroupSize>
-struct BLayout_TC_int4 {
- static constexpr int32_t kInnerKTiles = InnerKTiles;
- static constexpr int32_t kMTileSize = 16;
- static constexpr int32_t kNTileSize = 8;
- static constexpr int32_t kKTileSize = 16;
-
- static __device__ void load(
- // type uint32, size [n / 8][k / (InnerKTiles * 16)][32][InnerKTiles / 2]
- // n / 8: n-tiles (n8)
- // k / (InnerKTiles * 16): TC size per k-tile is 16 (m16n8k16)
- // 32: value per warp lane
- // (InnerKTiles / 2): B layout has 4 values per lane (16 bits) per k-tile.
- // 2 k-tiles packed is a uint32 (hence InnerKTiles == 2 is our smallest value)
- // 4 k-tiles packed is a uint32x2 (64 bits)
- // 8 k-tiles packed is a uint32x4 (128 bits)
- const void* __restrict__ B,
- // size [k / qGroupSize][n][2]
- // Contains the scale and zero point of each of the quantized int4 values
- // within B
- // v_reconstructed = (bf16(B_int4_val) * scale) - zero
- const void* __restrict__ quantizationInfo,
- int32_t n,
- int32_t k,
- int32_t nTiles,
- int32_t nTile,
- int32_t kTiles,
- int32_t kTileStart,
- int32_t laneId,
- bf16x2x4_u32 out[KTilesPerWarp / InnerKTiles][InnerKTiles / 2]) {
- // offset [nTile][kTileStart / InnerKTiles][laneId][0]
- auto bPtr = reinterpret_cast<const int32_t*>(B) +
- (((nTile * (kTiles / InnerKTiles) + (kTileStart / InnerKTiles)) *
- kWarpSize) +
- laneId) *
- (InnerKTiles / 2);
-
- int32_t b_int4[KTilesPerWarp / InnerKTiles][InnerKTiles / 2];
-
-#pragma unroll
- for (int i = 0; i < KTilesPerWarp / InnerKTiles; ++i) {
- auto bPtrCur = bPtr + i * kWarpSize * (InnerKTiles / 2);
-
- if constexpr (InnerKTiles == 2) {
- b_int4[i][0] = bPtrCur[0];
- }
-
- if constexpr (InnerKTiles == 4) {
-
- // asm volatile("ld.global.cs.v2.u32 {%0, %1}, [%2];\n"
- // : "=r"(b_int4[i][0]), "=r"(b_int4[i][1])
- // : "l"(bPtrCur));
-
- int2 load8 = reinterpret_cast<const int2*>(bPtrCur)[0];
- b_int4[i][0] = load8.x;
- b_int4[i][1] = load8.y;
- }
-
- if constexpr (InnerKTiles == 8) {
-
- // asm volatile("ld.global.cs.v4.u32 {%0, %1, %2, %3}, [%4];\n"
- // : "=r"(b_int4[i][0]), "=r"(b_int4[i][1]), "=r"(b_int4[i][2]), "=r"(b_int4[i][3])
- // : "l"(bPtrCur));
-
- int4 load16 = reinterpret_cast<const int4*>(bPtrCur)[0];
- b_int4[i][0] = load16.x;
- b_int4[i][1] = load16.y;
- b_int4[i][2] = load16.z;
- b_int4[i][3] = load16.w;
- }
- }
-
- // Load needed info for dequantization
-
- static_assert(isPowerOf2(QGroupSize), "");
- static_assert(isEvenDivisor(QGroupSize, kKTileSize), "");
- // smallest quantization group size is 32 (2 k-tiles are packed in an int32)
- static_assert(QGroupSize >= kKTileSize * 2, "");
- constexpr int kKTilesPerQGroup = (QGroupSize / kKTileSize);
- // a q-group could be larger than what we are handling in a single warp
- constexpr int kNumQGroups = (KTilesPerWarp / kKTilesPerQGroup) < 1
- ? 1
- : (KTilesPerWarp / kKTilesPerQGroup);
-
- __nv_bfloat162 qScaleAndZero[kNumQGroups];
- {
- int32_t laneN = nTile * kNTileSize + (laneId / 4);
- int32_t groupStart = (kTileStart * kKTileSize) / QGroupSize;
-
- int32_t n = nTiles * kNTileSize;
-
- // offset [qScale_kGroup][qScale_n][0]
- auto qInfoPtr = reinterpret_cast<const __nv_bfloat16*>(quantizationInfo) +
- (groupStart * n + laneN) * 2;
-
-#pragma unroll
- for (int i = 0; i < kNumQGroups; ++i) {
- qScaleAndZero[i] =
- *reinterpret_cast<const __nv_bfloat162*>(qInfoPtr + i * n * 2);
- }
- }
-
- //
- // De-quantize int4 values to bf16. Values are dequantized as truly int4
- // [-8, 7] range; dequant = (bf16(int4_value) * bf16_scale) + bf16_zero
- //
- {
- // FIXME: does this negatively affect register counts, or will nvcc
- // move this expansion (and data loads above) closer to the point of use?
- __nv_bfloat162 qScale[kNumQGroups];
- __nv_bfloat162 qZero[kNumQGroups];
-
-#pragma unroll
- for (int i = 0; i < kNumQGroups; ++i) {
- qScale[i] = __bfloat162bfloat162(qScaleAndZero[i].x);
- qZero[i] = __bfloat162bfloat162(qScaleAndZero[i].y);
- }
-
-#pragma unroll
- for (int i = 0; i < KTilesPerWarp / InnerKTiles; ++i) {
-#pragma unroll
- for (int j = 0; j < InnerKTiles / 2; ++j) {
- bf16x2x4 v = convert_i4x8_to_bf16x2x4(b_int4[i][j]);
-
- int curKTile = i * InnerKTiles + j * 2;
- int curQGroup = (curKTile * kKTileSize) / QGroupSize;
-
- // The dequantized values in `v` for a given lane have the same n
- // dimension (the B tensor core layout has all values in the same
- // thread along the same n) but different k dimension, but all are
- // guaranteed to occur within the same quantization group, so we need
- // only load a single scale + zero to cover what this lane has
-#pragma unroll
- for (int k = 0; k < 4; ++k) {
- v.vals[k] = __hfma2(v.vals[k], qScale[curQGroup], qZero[curQGroup]);
- }
-
- // type pun, the __nv_bfloat162 value in bf16x2x4 is a struct and can't
- // be used as a 32-bit asm register argument for `mma`
- static_assert(sizeof(bf16x2x4) == sizeof(out[0][0]), "");
- std::memcpy(&out[i][j], &v, sizeof(bf16x2x4_u32));
- }
- }
- }
- }
-};
-
-template <
- typename ALayout,
- typename BLayout,
- typename CLayout,
- int Warps,
- int KTilesPerWarp>
-__global__ __launch_bounds__(Warps * kWarpSize) void tinygemm_m16n8k16_chunk_kernel(
- // Data for the A matrix, loaded as per ALayout
- const void* __restrict__ A,
-
- // Data for the B matrix, loaded as per BLayout
- const void* __restrict__ B,
-
- // Optional quantization data for dequantizing B, loaded as per BLayout
- const void* __restrict__ B_quantizationInfo,
-
- // Output data for the C matrix, stored as per CLayout
- void* __restrict__ C,
-
- // The size of the matrix multiplication
- int32_t m,
- int32_t n,
- int32_t k,
-
- // The size of the matrix multiplication, in multiples of our TC tile size
- int32_t mTiles,
- int32_t nTiles,
- int32_t kTiles) {
- constexpr int32_t kMTileSize = 16;
- constexpr int32_t kNTileSize = 8;
- constexpr int32_t kKTileSize = 16;
-
- static_assert(
- ALayout::kMTileSize == kMTileSize && ALayout::kNTileSize == kNTileSize &&
- ALayout::kKTileSize == kKTileSize,
- "");
-
- static_assert(
- BLayout::kMTileSize == kMTileSize && BLayout::kNTileSize == kNTileSize &&
- BLayout::kKTileSize == kKTileSize,
- "");
-
- static_assert(
- CLayout::kMTileSize == kMTileSize && CLayout::kNTileSize == kNTileSize &&
- CLayout::kKTileSize == kKTileSize,
- "");
-
- constexpr int kInnerKTiles = BLayout::kInnerKTiles;
-
- // 2/4/8 inner k-tiles correspond to 4, 8 and 16 byte innermost loads
- static_assert(
- kInnerKTiles == 2 || kInnerKTiles == 4 || kInnerKTiles == 8, "");
-
- // We always process at least kInnerKTiles k-tiles back to back in a warp
- static_assert(KTilesPerWarp >= kInnerKTiles && isEvenDivisor(KTilesPerWarp, kInnerKTiles), "");
-
- auto warpId = threadIdx.y;
- auto laneId = threadIdx.x;
-
- int32_t mTile = blockIdx.z;
- int32_t nTile = blockIdx.y;
-
- float4 c{0.0f, 0.0f, 0.0f, 0.0f};
-
- // Requirement: kTiles must be an even multiple of Warps * KTilesPerWarp
- for (int32_t kTileBase = (blockIdx.x * Warps + warpId) * KTilesPerWarp;
- kTileBase < kTiles;
- kTileBase += Warps * KTilesPerWarp) {
- //
- // Load data from A
- //
- bf16x2x4_u32 a[KTilesPerWarp];
- ALayout::load(A, m, k, mTiles, mTile, kTiles, kTileBase, laneId, a);
-
- //
- // Load data from B and de-quantize as needed
- // Each k-tile is bf16x2x2
- //
- bf16x2x4_u32 b[KTilesPerWarp / kInnerKTiles][kInnerKTiles / 2];
- BLayout::load(
- B,
- B_quantizationInfo,
- n,
- k,
- nTiles,
- nTile,
- kTiles,
- kTileBase,
- laneId,
- b);
-
- //
- // Now, perform the matrix multiplication
- //
-
- // We accumulate across k-tiles here
-#pragma unroll
- for (int i = 0; i < KTilesPerWarp / kInnerKTiles; ++i) {
- static_assert(isEvenDivisor(kInnerKTiles, 2) && kInnerKTiles >= 2, "");
-#pragma unroll
- for (int j = 0; j < kInnerKTiles / 2; ++j) {
- // We don't simply accumulate into `c` as this creates a too-strong
- // execution dependency. Instead, we only periodically accumulate into
- // `c`
- float4 cTmp[2];
-
-#pragma unroll
- for (int k = 0; k < 2; ++k) {
- cTmp[k] = float4{0.0f, 0.0f, 0.0f, 0.0f};
- }
-
-#pragma unroll
- for (int k = 0; k < 2; ++k) {
- asm volatile(
- "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
- "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};"
- : "=f"(cTmp[k].x), "=f"(cTmp[k].y), "=f"(cTmp[k].z), "=f"(cTmp[k].w)
- : "r"(a[i * kInnerKTiles + j * 2 + k].vals[0]),
- "r"(a[i * kInnerKTiles + j * 2 + k].vals[1]),
- "r"(a[i * kInnerKTiles + j * 2 + k].vals[2]),
- "r"(a[i * kInnerKTiles + j * 2 + k].vals[3]),
- "r"(b[i][(j * 2 + k) / 2].vals[((j * 2 + k) % 2) * 2 + 0]),
- "r"(b[i][(j * 2 + k) / 2].vals[((j * 2 + k) % 2) * 2 + 1]),
- "f"(cTmp[k].x),
- "f"(cTmp[k].y),
- "f"(cTmp[k].z),
- "f"(cTmp[k].w));
- }
-
-#pragma unroll
- for (int k = 0; k < 2; ++k) {
- c.x += cTmp[k].x;
- c.y += cTmp[k].y;
- c.z += cTmp[k].z;
- c.w += cTmp[k].w;
- }
- }
- }
- } // for all kTiles
-
- //
- // Reduce independent k-tiles (same m/n) across warps
- //
- __shared__ float4 smem_sum[Warps][kWarpSize];
-
- smem_sum[warpId][laneId] = c;
-
- __syncthreads();
-
- if (warpId == 0) {
- float4 sum_f32{0.0f, 0.0f, 0.0f, 0.0f};
-
- // Reduce across the block in the first warp
- for (int i = 0; i < Warps; ++i) {
- float4 v = smem_sum[i][laneId];
- sum_f32.x += v.x;
- sum_f32.y += v.y;
- sum_f32.z += v.z;
- sum_f32.w += v.w;
- }
-
- // Write the reduced result (in the first warp) into the output
- CLayout::store(
- C,
- m,
- n,
- mTiles,
- mTile,
- // n for C output becomes k for A input, so for m16n8k16,
- // we need to halve the tiles
- nTiles / 2,
- nTile,
- laneId,
- sum_f32);
- }
-}
-
-
-template <
- typename ALayout,
- typename BLayout,
- typename CLayout,
- int Warps,
- int KTilesPerWarp>
-void launch_tinygemm_kernel(
- const torch::Tensor& A,
- const torch::Tensor& B,
- const torch::Tensor* qScaleAndZeros, /* optional */
- torch::Tensor& C_final,
- int32_t mTiles,
- int32_t nTiles,
- int32_t kTiles,
- int32_t m,
- int32_t n,
- int32_t k,
- cudaStream_t stream) {
- TORCH_CHECK(
- kTiles >= (Warps * KTilesPerWarp) &&
- isEvenDivisor(kTiles, Warps * KTilesPerWarp));
-
- TORCH_CHECK(
- KTilesPerWarp >= BLayout::kInnerKTiles && isPowerOf2(KTilesPerWarp));
-
- // After intra-block reduction across the k dimension, we are left with this
- // many tiles
- // int32_t postKernelKTiles = kTiles / (Warps * KTilesPerWarp);
- int32_t postKernelKTiles = 1; // we loop
-
- auto grid = dim3(postKernelKTiles, nTiles, mTiles);
- auto block = dim3(kWarpSize, Warps);
-
- auto func =
- tinygemm_m16n8k16_chunk_kernel<ALayout, BLayout, CLayout, Warps, KTilesPerWarp>;
-
- func<<<grid, block, 0, stream>>>(
- A.data_ptr(),
- B.data_ptr(),
- qScaleAndZeros ? qScaleAndZeros->data_ptr() : nullptr,
- C_final.data_ptr(),
- m,
- n,
- k,
- mTiles,
- nTiles,
- kTiles);
-
- cudaFuncAttributes funcAttr;
- C10_CUDA_CHECK(cudaFuncGetAttributes(
- &funcAttr,
- func));
-}
-
-// FIXME: parallelize better, smem staging etc?
-template <int InnerKTiles>
-__global__ void matrix_to_m16n8k16_Bint4_layout(
- // size [n][k]
- const at::PackedTensorAccessor32<int32_t, 2, at::RestrictPtrTraits> in,
- // size [ceil(n / 8)][ceil(k / (InnerKTiles * 16))][32][InnerKTiles / 2]
- at::PackedTensorAccessor32<int32_t, 4, at::RestrictPtrTraits> out) {
- // int4 values are packed into int32 values, which require at least 8. Given
- // m16n8k16 B layout requires 4 scalar values/lane, the minimum number of
- // innermost k-tiles that we can use is 2.
- static_assert(InnerKTiles >= 2 && isPowerOf2(InnerKTiles), "");
-
- constexpr int32_t kMTileSize = 16;
- constexpr int32_t kNTileSize = 8;
- constexpr int32_t kKTileSize = 16;
-
- // gridDim.x corresponds to the number of k-tiles divided by InnerKTiles
- auto kOuterTile = blockIdx.x;
- auto nTile = blockIdx.y;
- auto t = threadIdx.x;
-
- // Two k-tiles are packed into an int32 at a time
-#pragma unroll
- for (int innerKTile = 0; innerKTile < InnerKTiles; innerKTile += 2) {
- // n dimension that this lane loads from
- auto n0 = nTile * kNTileSize + (t / 4);
-
- bool n0Valid = n0 < in.size(0);
-
- int32_t ks[8];
-
- auto kBase0 = (kOuterTile * InnerKTiles + innerKTile) * kKTileSize;
- ks[0] = kBase0 + (t % 4) * 2;
- ks[1] = ks[0] + 1;
- ks[2] = ks[0] + 8;
- ks[3] = ks[0] + 8 + 1;
-
- auto kBase1 = kBase0 + kKTileSize;
- ks[4] = kBase1 + (t % 4) * 2;
- ks[5] = ks[4] + 1;
- ks[6] = ks[4] + 8;
- ks[7] = ks[4] + 8 + 1;
-
- auto pIn = &in[n0][0];
-
- uint32_t v[8];
-#pragma unroll
- for (int i = 0; i < 8; ++i) {
- v[i] = (n0Valid && ks[i] < in.size(1)) ? pIn[ks[i]] : uint32_t(0);
- }
-
- int32_t pack = (v[7] << 28) | (v[5] << 24) | (v[3] << 20) | (v[1] << 16) |
- (v[6] << 12) | (v[4] << 8) | (v[2] << 4) | v[0];
-
- // inner k-tiles pack two at a time
- out[nTile][kOuterTile][t][innerKTile / 2] = pack;
- }
-}
-
-#endif
-
-
-torch::Tensor _weight_int4pack_mm_cuda(
- const torch::Tensor& A,
- const torch::Tensor& B,
- int64_t qGroupSize,
- const torch::Tensor& qScaleAndZeros) {
- c10::cuda::CUDAGuard g(A.device());
- auto stream = at::cuda::getCurrentCUDAStream();
-
- TORCH_CHECK(
- A.device() == B.device() && A.device() == qScaleAndZeros.device());
-
- constexpr int32_t kMTileSize = 16;
- constexpr int32_t kNTileSize = 8;
- constexpr int32_t kKTileSize = 16;
-
- // row major layout
- auto m = A.size(0);
- auto mTiles = divUp(m, kMTileSize);
-
- // tensor core layout
- auto nTiles = B.size(0);
- auto n = nTiles * kNTileSize;
-
- // row major layout
- auto k = A.size(1);
- auto kTiles = divUp(k, kKTileSize);
-
- // The number of inner k tiles is the innermost dimension of times 2
- // 2 k-tiles (4 values per lane per tile, 8 values total) quantized to int4
- // packed into 1 int32 for int4 B
- auto B_innerKTiles = B.size(3) * 2;
- TORCH_CHECK(B_innerKTiles == 2 || B_innerKTiles == 4 || B_innerKTiles == 8);
-
- // A is standard row major
- TORCH_CHECK(A.dtype() == torch::kBFloat16);
- TORCH_CHECK(A.is_contiguous());
- TORCH_CHECK(A.dim() == 2);
-
- // B has B_innerKTiles k-tiles in the innermost dimension
- TORCH_CHECK(B.dtype() == torch::kInt32);
- TORCH_CHECK(B.is_contiguous());
- TORCH_CHECK(B.dim() == 4);
- TORCH_CHECK(B.size(1) == k / (B_innerKTiles * kKTileSize));
- TORCH_CHECK(B.size(2) == kWarpSize);
-
- // Validate the scale and zero point tensor for dequantization
- TORCH_CHECK(qScaleAndZeros.dim() == 3);
- auto numQGroups = qScaleAndZeros.size(0);
- TORCH_CHECK(k / numQGroups == qGroupSize);
- TORCH_CHECK(qScaleAndZeros.size(1) == n);
- TORCH_CHECK(qScaleAndZeros.size(2) == 2);
-
- // Output is a standard row-major matrix
- auto C_final = torch::empty(
- {m, n},
- torch::TensorOptions().dtype(at::kBFloat16).device(A.device()));
-
-#if (defined(CUDA_VERSION) && CUDA_VERSION >= 12000) && (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 800))
-#define RUN_GEMM(WARPS, K_TILES_PER_WARP, Q_GROUP_SIZE, REDUCE_TYPE) \
- do { \
- using ACLayout = ALayout_RM<K_TILES_PER_WARP, REDUCE_TYPE>; \
- \
- TORCH_CHECK( \
- K_TILES_PER_WARP >= B_innerKTiles && \
- isEvenDivisor(K_TILES_PER_WARP, B_innerKTiles)); \
- \
- switch (B_innerKTiles) { \
- case 2: \
- if constexpr (K_TILES_PER_WARP >= 2) { \
- using BLayout = BLayout_TC_int4<K_TILES_PER_WARP, 2, Q_GROUP_SIZE>; \
- launch_tinygemm_kernel< \
- ACLayout, \
- BLayout, \
- ACLayout, \
- WARPS, \
- K_TILES_PER_WARP>( \
- A, \
- B, \
- &qScaleAndZeros, \
- C_final, \
- mTiles, \
- nTiles, \
- kTiles, \
- m, \
- n, \
- k, \
- stream); \
- } \
- break; \
- case 4: \
- if constexpr (K_TILES_PER_WARP >= 4) { \
- using BLayout = BLayout_TC_int4<K_TILES_PER_WARP, 4, Q_GROUP_SIZE>; \
- launch_tinygemm_kernel< \
- ACLayout, \
- BLayout, \
- ACLayout, \
- WARPS, \
- K_TILES_PER_WARP>( \
- A, \
- B, \
- &qScaleAndZeros, \
- C_final, \
- mTiles, \
- nTiles, \
- kTiles, \
- m, \
- n, \
- k, \
- stream); \
- } \
- break; \
- case 8: \
- if constexpr (K_TILES_PER_WARP >= 8) { \
- using BLayout = BLayout_TC_int4<K_TILES_PER_WARP, 8, Q_GROUP_SIZE>; \
- launch_tinygemm_kernel< \
- ACLayout, \
- BLayout, \
- ACLayout, \
- WARPS, \
- K_TILES_PER_WARP>( \
- A, \
- B, \
- &qScaleAndZeros, \
- C_final, \
- mTiles, \
- nTiles, \
- kTiles, \
- m, \
- n, \
- k, \
- stream); \
- } \
- break; \
- default: \
- break; \
- } \
- } while (false)
-
-#define HANDLE_Q_GROUP(WARPS, K_TILES_PER_WARP, REDUCE_TYPE) \
- do { \
- switch (qGroupSize) { \
- case 32: \
- RUN_GEMM(WARPS, K_TILES_PER_WARP, 32, REDUCE_TYPE); \
- break; \
- case 64: \
- RUN_GEMM(WARPS, K_TILES_PER_WARP, 64, REDUCE_TYPE); \
- break; \
- case 128: \
- RUN_GEMM(WARPS, K_TILES_PER_WARP, 128, REDUCE_TYPE); \
- break; \
- case 256: \
- RUN_GEMM(WARPS, K_TILES_PER_WARP, 256, REDUCE_TYPE); \
- break; \
- } \
- } while (false)
-
- // These are the only versions handled at the moment
- TORCH_CHECK(k == 32 || k == 64 || isEvenDivisor(k, 1024));
- TORCH_CHECK(
- qGroupSize == 32 || qGroupSize == 64 || qGroupSize == 128 ||
- qGroupSize == 256);
-
- if (k == 32) {
- HANDLE_Q_GROUP(1, 2, KReductionType::None);
- } else if (k == 64) {
- HANDLE_Q_GROUP(1, 4, KReductionType::None);
- } else {
- // Handle k = 1024 (8 * 8 * kKTileSize) at a time
- HANDLE_Q_GROUP(8, 8, KReductionType::None);
- }
-
-#undef RUN_GEMM
-
- return C_final;
-#endif
- TORCH_CHECK(false, "_weight_int4pack_mm_cuda is not available for build.")
- return C_final;
-}
-
-// input is [n][k] (int32 dtype)
-// output is [n / 8][k / (InnerKTiles * 16)][32][innerKTiles / 2]
-torch::Tensor _convert_weight_to_int4pack_cuda(
- const torch::Tensor& in,
- int64_t innerKTiles) {
- c10::cuda::CUDAGuard g(in.device());
- auto stream = at::cuda::getCurrentCUDAStream();
-
- TORCH_CHECK(in.dim() == 2);
- TORCH_CHECK(in.dtype() == torch::kInt32);
- TORCH_CHECK(in.is_contiguous());
-
- // At least 2 k-tiles need to be packed back to back in the innermost
- // dimension, as the m16n8k16 tensor core tile presents 4 scalar values for
- // the B matrix, but the minimum word size for the packed format is 4 bytes
- // (int32). 4 inner K-tiles = 8 byte load, 8 inner k-tiles = 16 byte load
- // which is the maximum vectorized load/store size
- TORCH_CHECK(innerKTiles == 2 || innerKTiles == 4 || innerKTiles == 8);
-
- constexpr int32_t kMTileSize = 16;
- constexpr int32_t kNTileSize = 8;
- constexpr int32_t kKTileSize = 16;
-
- auto nTiles = divUp(in.size(0), kNTileSize);
-
- // k-tiles are packed back to back in the innermost dimension in order to
- // allow for 4/8/16 byte loads
- TORCH_CHECK(isEvenDivisor(in.size(1), innerKTiles * kKTileSize));
- // kSuperTiles is the number of k-tiles assuming k is innerKTiles * kKTileSize
- auto kSuperTiles = divUp(in.size(1), innerKTiles * kKTileSize);
-
- // each block handles `innerKTiles` k-tiles.
- // 2 k-tiles are a single int32
- auto out = torch::empty(
- {nTiles, kSuperTiles, 32, innerKTiles / 2},
- torch::TensorOptions().dtype(torch::kInt32).device(in.device()));
-
-#if (defined(CUDA_VERSION) && CUDA_VERSION >= 12000) && (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 800))
- dim3 grid(kSuperTiles, nTiles);
-
- if (innerKTiles == 2) {
- matrix_to_m16n8k16_Bint4_layout<2><<<grid, kWarpSize, 0, stream>>>(
- in.packed_accessor32<int32_t, 2, at::RestrictPtrTraits>(),
- out.packed_accessor32<int32_t, 4, at::RestrictPtrTraits>());
- } else if (innerKTiles == 4) {
- matrix_to_m16n8k16_Bint4_layout<4><<<grid, kWarpSize, 0, stream>>>(
- in.packed_accessor32<int32_t, 2, at::RestrictPtrTraits>(),
- out.packed_accessor32<int32_t, 4, at::RestrictPtrTraits>());
- } else if (innerKTiles == 8) {
- matrix_to_m16n8k16_Bint4_layout<8><<<grid, kWarpSize, 0, stream>>>(
- in.packed_accessor32<int32_t, 2, at::RestrictPtrTraits>(),
- out.packed_accessor32<int32_t, 4, at::RestrictPtrTraits>());
- }
-
- return out;
-#endif
- TORCH_CHECK(false, "_convert_weight_to_int4pack_cuda is not available for build.")
- return out;
-}
-
-
-} // namespace at::native
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 36a2f4e..5ffad03 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -4019,14 +4019,6 @@
dispatch:
CUDA: _int_mm_out_cuda
-- func: _convert_weight_to_int4pack(Tensor self, int innerKTiles) -> Tensor
- dispatch:
- CUDA: _convert_weight_to_int4pack_cuda
-
-- func: _weight_int4pack_mm(Tensor self, Tensor mat2, int qGroupSize, Tensor qScaleAndZeros) -> Tensor
- dispatch:
- CUDA: _weight_int4pack_mm_cuda
-
- func: _sparse_mm(Tensor sparse, Tensor dense) -> Tensor
python_module: sparse
diff --git a/test/expect/HasDecompTest.test_has_decomposition.expect b/test/expect/HasDecompTest.test_has_decomposition.expect
index 8e84155..d460d19 100644
--- a/test/expect/HasDecompTest.test_has_decomposition.expect
+++ b/test/expect/HasDecompTest.test_has_decomposition.expect
@@ -54,7 +54,6 @@
aten::_convert_indices_from_coo_to_csr.out
aten::_convert_indices_from_csr_to_coo
aten::_convert_indices_from_csr_to_coo.out
-aten::_convert_weight_to_int4pack
aten::_convolution
aten::_convolution.out
aten::_copy_from
@@ -588,7 +587,6 @@
aten::_values
aten::_values_copy
aten::_values_copy.out
-aten::_weight_int4pack_mm
aten::_weight_norm_interface
aten::_weight_norm_interface.out
aten::_weight_norm_interface_backward
diff --git a/test/test_linalg.py b/test/test_linalg.py
index 17217d9..a1e119a 100644
--- a/test/test_linalg.py
+++ b/test/test_linalg.py
@@ -29,7 +29,7 @@
all_types, all_types_and_complex_and, floating_and_complex_types, integral_types,
floating_and_complex_types_and, floating_types_and, complex_types,
)
-from torch.testing._internal.common_cuda import SM53OrLater, SM80OrLater, tf32_on_and_off, _get_magma_version, \
+from torch.testing._internal.common_cuda import SM53OrLater, tf32_on_and_off, _get_magma_version, \
_get_torch_cuda_version
from torch.distributions.binomial import Binomial
import torch.backends.opt_einsum as opt_einsum
@@ -5771,123 +5771,6 @@
r"Expected result.size\(0\) to be 17 but got 16",
lambda: torch._int_mm(genf_int(17, 8), genf_int(8, 32), out=genf_int(16, 31).int()))
- def _group_quantize_tensor(self, w, n_bit=4, q_group_size=16):
- assert w.dim() == 2
- w = w.transpose(0, 1).contiguous()
- assert q_group_size > 1
- assert w.shape[-1] % q_group_size == 0
-
- to_quant = w.reshape(-1, q_group_size)
- assert torch.isnan(to_quant).sum() == 0
-
- max_val = to_quant.amax(dim=1, keepdim=True)
- min_val = to_quant.amin(dim=1, keepdim=True)
- max_int = 2 ** n_bit - 1
- min_int = 0
- scales = (max_val - min_val).clamp(min=1e-6) / max_int
- assert torch.isnan(scales).sum() == 0
-
- zeros = min_val + scales * (2 ** (n_bit - 1))
- assert torch.isnan(zeros).sum() == 0
-
- out = to_quant.sub(min_val).div(scales).round().clamp_(min_int, max_int)
- assert torch.isnan(out).sum() == 0
-
- out = out.to(dtype=torch.int32).reshape(w.shape)
-
- # Scales and zeros for the same q-group should be contiguous, so we can
- # load as a 32-bit word
- scales = scales.view(w.shape[0], -1)
- zeros = zeros.view(w.shape[0], -1)
- scales_and_zeros = (
- torch.cat(
- [
- scales.reshape(scales.size(0), scales.size(1), 1),
- zeros.reshape(zeros.size(0), zeros.size(1), 1),
- ],
- 2,
- ).transpose(0, 1).contiguous()
- )
-
- return out, scales_and_zeros
-
- @unittest.skipIf(IS_WINDOWS, "Skipped on Windows!")
- @unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "cublas runtime error")
- @unittest.skipIf(not SM80OrLater, "need sm_80")
- @onlyCUDA
- @parametrize("m", [32, 64])
- @parametrize("k", [32, 64])
- @parametrize("n", [48, 64])
- def test__int4_mm(self, device, m, k, n):
- if TEST_WITH_ROCM:
- self.skipTest("_int4_mm not compiled for ROCM")
-
- q_group = 32
- inner_k_tiles = 2
-
- torch.manual_seed(1)
- a = torch.rand((m, k), dtype=torch.bfloat16, device=device)
- b = torch.rand((k, n), dtype=torch.bfloat16, device=device)
-
- def convert_weight_to_int4pack(b):
- b_int32, b_scales_and_zeros = self._group_quantize_tensor(
- b, n_bit=4, q_group_size=q_group
- )
- b_int4pack = torch._convert_weight_to_int4pack(
- b_int32, inner_k_tiles
- )
-
- return b_int4pack, b_scales_and_zeros
-
- def weight_int4pack_mm(a, b_int4pack, b_scales_and_zeros):
- return torch._weight_int4pack_mm(
- a, b_int4pack, q_group, b_scales_and_zeros
- )
-
- b_int4pack, b_scales_and_zeros = convert_weight_to_int4pack(b)
- res = weight_int4pack_mm(a, b_int4pack, b_scales_and_zeros)
- ref = torch.mm(a, b)
-
- mean_err = ((res - ref).abs() / ref).mean()
- self.assertTrue(mean_err < 0.05)
-
- @unittest.skipIf(IS_WINDOWS, "Skipped on Windows!")
- @unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "cublas runtime error")
- @unittest.skipIf(not SM80OrLater, "need sm_80")
- @onlyCUDA
- @parametrize("m", [32, 64])
- @parametrize("k", [32, 64])
- @parametrize("n", [48, 64])
- def test_compile_int4_mm(self, device, m, k, n):
- if TEST_WITH_ROCM:
- self.skipTest("_int4_mm not compiled for ROCM")
-
- q_group = 32
- inner_k_tiles = 2
-
- torch.manual_seed(1)
- a = torch.rand((m, k), dtype=torch.bfloat16, device=device)
- b = torch.rand((k, n), dtype=torch.bfloat16, device=device)
-
- b_int32, b_scales_and_zeros = self._group_quantize_tensor(
- b, n_bit=4, q_group_size=q_group
- )
-
- @torch.compile
- def int4_mm(a, b_int32, b_scales_and_zeros):
- b_int4pack = torch._convert_weight_to_int4pack(
- b_int32, inner_k_tiles
- )
- return torch._weight_int4pack_mm(
- a, b_int4pack, q_group, b_scales_and_zeros
- )
-
- res = int4_mm(a, b_int32, b_scales_and_zeros)
- ref = torch.mm(a, b)
-
- mean_err = ((res - ref).abs() / ref).mean()
- self.assertTrue(mean_err < 0.05)
-
@slowTest
@onlyNativeDeviceTypes
# bfloat16 doesn't have sufficient precision to pass this test
diff --git a/torch/_meta_registrations.py b/torch/_meta_registrations.py
index 3afcc2e..98097f5 100644
--- a/torch/_meta_registrations.py
+++ b/torch/_meta_registrations.py
@@ -3216,41 +3216,6 @@
return a.new_empty((a.size(0), b.size(1)), dtype=torch.int32)
-@register_meta([aten._convert_weight_to_int4pack])
-def meta__convert_weight_to_int4pack(w, inner_k_tiles):
- torch._check(w.dim() == 2, lambda: "w must be a 2D tensor")
- torch._check(
- w.dtype is torch.int32,
- lambda: f"expected w to be int32, got {w.dtype}",
- )
- n = w.size(0)
- k = w.size(1)
- return w.new_empty(
- (
- n // 8,
- k // (inner_k_tiles * 16),
- 32,
- inner_k_tiles // 2,
- ),
- dtype=torch.int32,
- )
-
-
-@register_meta([aten._weight_int4pack_mm])
-def meta__weight_int4pack_mm(x, w, q_group_size, q_scale_and_zeros):
- torch._check(x.dim() == 2, lambda: "x must be a 2D tensor")
- torch._check(w.dim() == 4, lambda: "w must be a 4D tensor")
- torch._check(
- x.dtype is torch.bfloat16,
- lambda: f"expected x to be bf16, got {x.dtype}",
- )
- torch._check(
- w.dtype is torch.int32,
- lambda: f"expected w to be int32, got {w.dtype}",
- )
- return x.new_empty(x.size(0), w.size(0) * 8, dtype=x.dtype)
-
-
@register_meta(aten._cdist_forward.default)
def meta_cdist_forward(x1, x2, p, compute_mode):
torch._check(