Sync from v0.13
This commit is contained in:
76
tests/kernels/quantization/nvfp4_utils.py
Normal file
76
tests/kernels/quantization/nvfp4_utils.py
Normal file
@@ -0,0 +1,76 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import torch
|
||||
|
||||
from vllm._custom_ops import scaled_fp4_quant
|
||||
from vllm.scalar_type import scalar_types
|
||||
|
||||
FLOAT4_E2M1_MAX = scalar_types.float4_e2m1f.max()
|
||||
FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max
|
||||
|
||||
kE2M1ToFloat = torch.tensor(
|
||||
[0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0], dtype=torch.float32
|
||||
)
|
||||
|
||||
|
||||
def convert_swizzled_to_linear(a_sf_swizzled: torch.Tensor, m, k, block_size):
|
||||
m_tiles = (m + 128 - 1) // 128
|
||||
f = block_size * 4
|
||||
k_tiles = (k + f - 1) // f
|
||||
tmp = torch.reshape(a_sf_swizzled, (1, m_tiles, k_tiles, 32, 4, 4))
|
||||
tmp = torch.permute(tmp, (0, 1, 4, 3, 2, 5))
|
||||
out = tmp.reshape(m_tiles * 128, k_tiles * f // block_size)
|
||||
return out[0:m, 0:k]
|
||||
|
||||
|
||||
def dequantize_nvfp4_to_dtype(
|
||||
tensor_fp4, tensor_sf, global_scale, dtype, device, block_size=16
|
||||
):
|
||||
"""Dequantize the fp4 tensor back to high precision."""
|
||||
# Two fp4 values are packed into one uint8.
|
||||
assert tensor_fp4.dtype == torch.uint8
|
||||
m, packed_k = tensor_fp4.shape
|
||||
k = packed_k * 2
|
||||
tensor_f32 = break_fp4_bytes(tensor_fp4, dtype)
|
||||
tensor_f32 = tensor_f32.reshape(m, k // block_size, block_size)
|
||||
tensor_sf = tensor_sf.view(torch.float8_e4m3fn)
|
||||
tensor_sf = convert_swizzled_to_linear(tensor_sf, m, k, block_size)
|
||||
tensor_sf_dtype = tensor_sf.to(torch.float32) / global_scale
|
||||
|
||||
# scale the tensor
|
||||
out = (tensor_f32 * tensor_sf_dtype.unsqueeze(-1)).reshape(m, k)
|
||||
return out.to(dtype=dtype)
|
||||
|
||||
|
||||
def break_fp4_bytes(a, dtype):
|
||||
assert a.dtype == torch.uint8
|
||||
m, n = a.shape
|
||||
|
||||
# Vectorized nibble processing
|
||||
a_flat = a.flatten()
|
||||
high = (a_flat & 0xF0) >> 4 # Upper nibbles
|
||||
low = a_flat & 0x0F # Lower nibbles
|
||||
|
||||
# Combine nibbles for batch processing
|
||||
combined = torch.stack((low, high), dim=1).flatten()
|
||||
|
||||
# Vectorized sign and magnitude extraction
|
||||
signs = (combined & 0x08).to(torch.bool) # Sign bits
|
||||
abs_vals = (combined & 0x07).to(torch.long) # Magnitude indices
|
||||
|
||||
# Device-aware lookup and sign application
|
||||
kE2M1 = kE2M1ToFloat.to(device=a.device)
|
||||
values = kE2M1[abs_vals] * torch.where(signs, -1.0, 1.0)
|
||||
|
||||
# Reshape to final form
|
||||
return values.reshape(m, n * 2).to(dtype=dtype)
|
||||
|
||||
|
||||
def get_nvfp4_global_scale(a: torch.Tensor):
|
||||
return (FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.abs(a).max().to(torch.float32)
|
||||
|
||||
|
||||
def quant_nvfp4_tensor(a: torch.Tensor):
|
||||
a_global_scale = get_nvfp4_global_scale(a)
|
||||
a_quant, a_block_scale = scaled_fp4_quant(a, a_global_scale)
|
||||
return a_quant, a_block_scale, a_global_scale
|
||||
127
tests/kernels/quantization/test_allspark_gemm.py
Normal file
127
tests/kernels/quantization/test_allspark_gemm.py
Normal file
@@ -0,0 +1,127 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from tests.kernels.utils import DEFAULT_OPCHECK_TEST_UTILS, opcheck
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.model_executor.layers.quantization.utils.allspark_utils import (
|
||||
ALLSPARK_AMPERE_K_ALIGN,
|
||||
ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD,
|
||||
ALLSPARK_AMPERE_N_ALIGN,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import quantize_weights
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.scalar_type import scalar_types
|
||||
|
||||
|
||||
def is_gptq_allspark_supported(min_capability: int, max_capability: int) -> bool:
|
||||
if not current_platform.is_cuda():
|
||||
return False
|
||||
|
||||
capability = current_platform.get_device_capability()
|
||||
assert capability is not None
|
||||
|
||||
return (
|
||||
capability.to_int() >= min_capability and capability.to_int() <= max_capability
|
||||
)
|
||||
|
||||
|
||||
MNK_FACTORS = [
|
||||
(1, 4, 8),
|
||||
(13, 17, 67),
|
||||
(26, 37, 13),
|
||||
(48, 16, 24),
|
||||
(67, 13, 88),
|
||||
(257, 13, 11),
|
||||
(658, 13, 11),
|
||||
(1033, 9, 17),
|
||||
]
|
||||
|
||||
DTYPES = [torch.float16, torch.bfloat16]
|
||||
HAS_ZP_OPTS = [False, True]
|
||||
|
||||
|
||||
def compute_max_diff(output, output_ref):
|
||||
return torch.mean(torch.abs(output - output_ref)) / torch.mean(
|
||||
torch.abs(output_ref)
|
||||
)
|
||||
|
||||
|
||||
def rand_data(shape, dtype=torch.float16):
|
||||
return torch.randn(shape, dtype=dtype, device="cuda")
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not is_gptq_allspark_supported(80, 89),
|
||||
reason="AllSpark Ampere kernel is not supported on this GPU type.",
|
||||
)
|
||||
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
|
||||
@pytest.mark.parametrize("group_size", [-1])
|
||||
@pytest.mark.parametrize("has_zp", HAS_ZP_OPTS)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
def test_gptq_allspark_gemm_ampere(mnk_factors, group_size, has_zp, dtype):
|
||||
m_factor, n_factor, k_factor = mnk_factors
|
||||
m = m_factor
|
||||
n = n_factor * ALLSPARK_AMPERE_N_ALIGN
|
||||
k = k_factor * ALLSPARK_AMPERE_K_ALIGN
|
||||
|
||||
input = rand_data((m, k), dtype=dtype)
|
||||
weight = rand_data((k, n), dtype=dtype)
|
||||
|
||||
# Quantize (and apply act_order if provided)
|
||||
w_ref, qw, s, zp = quantize_weights(
|
||||
weight, scalar_types.uint8b128, group_size, has_zp
|
||||
)
|
||||
|
||||
qw = qw.to(torch.uint8)
|
||||
if has_zp:
|
||||
zp = zp.to(dtype)
|
||||
properties = torch.cuda.get_device_properties(qw.device.index)
|
||||
sm_count = properties.multi_processor_count
|
||||
sm_version = properties.major * 10 + properties.minor
|
||||
|
||||
n_32align = (n + 32 - 1) // 32 * 32
|
||||
|
||||
qw_reorder, s_reorder, zp_reorder = ops.allspark_repack_weight(qw, s, zp, has_zp)
|
||||
opcheck(
|
||||
torch.ops._C.rearrange_kn_weight_as_n32k16_order,
|
||||
(qw, s, zp, has_zp, qw_reorder, s_reorder, zp_reorder, k, n, n_32align),
|
||||
)
|
||||
|
||||
opcheck(
|
||||
torch.ops._C.allspark_w8a16_gemm,
|
||||
(
|
||||
input,
|
||||
qw_reorder,
|
||||
s_reorder,
|
||||
zp_reorder,
|
||||
n,
|
||||
group_size,
|
||||
sm_count,
|
||||
sm_version,
|
||||
ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD,
|
||||
has_zp,
|
||||
True,
|
||||
),
|
||||
test_utils=DEFAULT_OPCHECK_TEST_UTILS,
|
||||
)
|
||||
output = ops.allspark_w8a16_gemm(
|
||||
input,
|
||||
qw_reorder,
|
||||
s_reorder,
|
||||
zp_reorder,
|
||||
n,
|
||||
group_size,
|
||||
sm_count,
|
||||
sm_version,
|
||||
ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD,
|
||||
has_zp,
|
||||
True,
|
||||
)
|
||||
|
||||
output_ref = torch.matmul(input, w_ref)
|
||||
torch.cuda.synchronize()
|
||||
max_diff = compute_max_diff(output, output_ref)
|
||||
|
||||
assert max_diff < 0.04
|
||||
49
tests/kernels/quantization/test_awq.py
Normal file
49
tests/kernels/quantization/test_awq.py
Normal file
@@ -0,0 +1,49 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from tests.kernels.utils import opcheck
|
||||
from vllm import _custom_ops as ops # noqa: F401
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not hasattr(torch.ops._C, "awq_dequantize"),
|
||||
reason="AWQ is not supported on this GPU type.",
|
||||
)
|
||||
def test_awq_dequantize_opcheck(monkeypatch: pytest.MonkeyPatch):
|
||||
with monkeypatch.context() as m:
|
||||
m.setenv("VLLM_USE_TRITON_AWQ", "0")
|
||||
qweight = torch.randint(
|
||||
-2000000000, 2000000000, (8192, 256), device="cuda", dtype=torch.int32
|
||||
)
|
||||
scales = torch.rand((64, 2048), device="cuda", dtype=torch.float16)
|
||||
zeros = torch.empty((64, 256), device="cuda", dtype=torch.int32)
|
||||
split_k_iters = 0
|
||||
thx = 0
|
||||
thy = 0
|
||||
opcheck(
|
||||
torch.ops._C.awq_dequantize,
|
||||
(qweight, scales, zeros, split_k_iters, thx, thy),
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Not working; needs investigation.")
|
||||
@pytest.mark.skipif(
|
||||
not hasattr(torch.ops._C, "awq_gemm"),
|
||||
reason="AWQ is not supported on this GPU type.",
|
||||
)
|
||||
def test_awq_gemm_opcheck(monkeypatch: pytest.MonkeyPatch):
|
||||
with monkeypatch.context() as m:
|
||||
m.setenv("VLLM_USE_TRITON_AWQ", "0")
|
||||
input = torch.rand((2, 8192), device="cuda", dtype=torch.float16)
|
||||
qweight = torch.randint(
|
||||
-2000000000, 2000000000, (8192, 256), device="cuda", dtype=torch.int32
|
||||
)
|
||||
scales = torch.empty((64, 2048), device="cuda", dtype=torch.float16)
|
||||
qzeros = torch.randint(
|
||||
-2000000000, 2000000000, (64, 256), device="cuda", dtype=torch.int32
|
||||
)
|
||||
split_k_iters = 8
|
||||
opcheck(torch.ops._C.awq_gemm, (input, qweight, scales, qzeros, split_k_iters))
|
||||
171
tests/kernels/quantization/test_awq_triton.py
Normal file
171
tests/kernels/quantization/test_awq_triton.py
Normal file
@@ -0,0 +1,171 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Tests for the AWQ Triton kernel.
|
||||
|
||||
Run `pytest tests/kernels/quantization/test_awq_triton.py`.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.layers.quantization.awq_triton import (
|
||||
AWQ_TRITON_SUPPORTED_GROUP_SIZES,
|
||||
awq_dequantize_triton,
|
||||
awq_gemm_triton,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
device = "cuda"
|
||||
|
||||
|
||||
def reverse_awq_order(t: torch.Tensor):
|
||||
bits = 4
|
||||
AWQ_REVERSE_ORDER = [0, 4, 1, 5, 2, 6, 3, 7]
|
||||
reverse_order_tensor = torch.arange(
|
||||
t.shape[-1],
|
||||
dtype=torch.int32,
|
||||
device=t.device,
|
||||
)
|
||||
reverse_order_tensor = reverse_order_tensor.view(-1, 32 // bits)
|
||||
reverse_order_tensor = reverse_order_tensor[:, AWQ_REVERSE_ORDER]
|
||||
reverse_order_tensor = reverse_order_tensor.view(-1)
|
||||
|
||||
t = t[:, reverse_order_tensor] & 0xF
|
||||
return t
|
||||
|
||||
|
||||
# qweights - [R , C // 8], int32
|
||||
# scales - [R // G, C ], float16
|
||||
# zeros - [R // G, C // 8], int32
|
||||
def awq_dequantize_torch(
|
||||
qweight: torch.Tensor, scales: torch.Tensor, qzeros: torch.Tensor, group_size: int
|
||||
) -> torch.Tensor:
|
||||
if group_size == -1:
|
||||
group_size = qweight.shape[0]
|
||||
|
||||
bits = 4
|
||||
shifts = torch.arange(0, 32, bits, device=qzeros.device)
|
||||
|
||||
iweights = torch.bitwise_right_shift(qweight[:, :, None], shifts[None, None, :]).to(
|
||||
torch.int8
|
||||
)
|
||||
|
||||
iweights = iweights.view(iweights.shape[0], -1)
|
||||
|
||||
zeros = torch.bitwise_right_shift(qzeros[:, :, None], shifts[None, None, :]).to(
|
||||
torch.int8
|
||||
)
|
||||
zeros = zeros.view(qzeros.shape[0], -1)
|
||||
zeros = reverse_awq_order(zeros)
|
||||
|
||||
iweights = reverse_awq_order(iweights)
|
||||
|
||||
iweights = torch.bitwise_and(iweights, (2**bits) - 1)
|
||||
zeros = torch.bitwise_and(zeros, (2**bits) - 1)
|
||||
|
||||
scales = scales.repeat_interleave(group_size, dim=0)
|
||||
zeros = zeros.repeat_interleave(group_size, dim=0)
|
||||
return (iweights - zeros) * scales
|
||||
|
||||
|
||||
# qweights - [R , C // 8], int32
|
||||
# scales - [R // G, C ], float16
|
||||
# zeros - [R // G, C // 8], int32
|
||||
@pytest.mark.parametrize("qweight_rows", [3584, 18944, 128, 256, 512, 1024])
|
||||
@pytest.mark.parametrize("qweight_cols", [448, 576, 4736, 16, 32, 64, 128])
|
||||
@pytest.mark.parametrize("group_size", AWQ_TRITON_SUPPORTED_GROUP_SIZES)
|
||||
def test_dequantize(qweight_rows, qweight_cols, group_size):
|
||||
if group_size == -1:
|
||||
group_size = qweight_rows
|
||||
|
||||
qweight_dtype = torch.int32
|
||||
scales_rows = qweight_rows // group_size
|
||||
scales_cols = qweight_cols * 8
|
||||
scales_dtype = torch.float16
|
||||
zeros_rows = scales_rows
|
||||
zeros_cols = qweight_cols
|
||||
zeros_dtype = torch.int32
|
||||
|
||||
current_platform.seed_everything(0)
|
||||
|
||||
qweight = torch.randint(
|
||||
0,
|
||||
torch.iinfo(torch.int32).max,
|
||||
(qweight_rows, qweight_cols),
|
||||
dtype=qweight_dtype,
|
||||
device=device,
|
||||
)
|
||||
scales = torch.rand(scales_rows, scales_cols, dtype=scales_dtype, device=device)
|
||||
zeros = torch.randint(
|
||||
0,
|
||||
torch.iinfo(torch.int32).max,
|
||||
(zeros_rows, zeros_cols),
|
||||
dtype=zeros_dtype,
|
||||
device=device,
|
||||
)
|
||||
|
||||
iweights_triton = awq_dequantize_triton(qweight, scales, zeros)
|
||||
|
||||
assert not torch.any(torch.isinf(iweights_triton)) and not torch.any(
|
||||
torch.isnan(iweights_triton)
|
||||
)
|
||||
|
||||
iweights_torch = awq_dequantize_torch(qweight, scales, zeros, group_size)
|
||||
|
||||
torch.testing.assert_close(iweights_triton, iweights_torch)
|
||||
|
||||
|
||||
# input - [N, K]
|
||||
# qweight - [K, M // 8]
|
||||
# qzeros - [K // G, M // 8]
|
||||
# scales - [K // G, M]
|
||||
@pytest.mark.parametrize("N", [1, 2, 4, 8, 14, 17, 23, 32])
|
||||
@pytest.mark.parametrize("K", [128])
|
||||
@pytest.mark.parametrize("M", [16, 24, 32])
|
||||
@pytest.mark.parametrize("group_size", AWQ_TRITON_SUPPORTED_GROUP_SIZES)
|
||||
@pytest.mark.parametrize("splitK", [1, 8])
|
||||
def test_gemm(N, K, M, splitK, group_size):
|
||||
if group_size == -1:
|
||||
group_size = K
|
||||
|
||||
split_k_iters = splitK
|
||||
|
||||
input_rows = N
|
||||
input_cols = K
|
||||
input_dtype = torch.float32
|
||||
qweight_rows = input_cols
|
||||
qweight_cols = M // 8
|
||||
scales_rows = qweight_rows // group_size
|
||||
scales_cols = M
|
||||
scales_dtype = torch.float32
|
||||
qzeros_rows = scales_rows
|
||||
qzeros_cols = qweight_cols
|
||||
|
||||
current_platform.seed_everything(0)
|
||||
|
||||
input = torch.rand((input_rows, input_cols), dtype=input_dtype, device=device)
|
||||
qweight = torch.randint(
|
||||
0, torch.iinfo(torch.int32).max, (qweight_rows, qweight_cols), device=device
|
||||
)
|
||||
qzeros = torch.randint(
|
||||
0, torch.iinfo(torch.int32).max, (qzeros_rows, qzeros_cols), device=device
|
||||
)
|
||||
scales = torch.rand((scales_rows, scales_cols), dtype=scales_dtype, device=device)
|
||||
|
||||
output_triton = awq_gemm_triton(input, qweight, scales, qzeros, split_k_iters)
|
||||
|
||||
assert not torch.any(torch.isinf(output_triton)) and not torch.any(
|
||||
torch.isnan(output_triton)
|
||||
)
|
||||
|
||||
dequantized_weights = awq_dequantize_triton(qweight, scales, qzeros)
|
||||
|
||||
output_torch = torch.matmul(input, dequantized_weights)
|
||||
|
||||
assert not torch.any(torch.isinf(output_torch)) and not torch.any(
|
||||
torch.isnan(output_torch)
|
||||
)
|
||||
|
||||
torch.testing.assert_close(
|
||||
output_triton.cpu(), output_torch.cpu(), atol=1e-1, rtol=1e-1
|
||||
)
|
||||
207
tests/kernels/quantization/test_block_fp8.py
Normal file
207
tests/kernels/quantization/test_block_fp8.py
Normal file
@@ -0,0 +1,207 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
# Adapted from https://github.com/sgl-project/sglang/pull/2575
|
||||
import itertools
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from tests.kernels.quant_utils import (
|
||||
native_per_token_group_quant_fp8,
|
||||
native_w8a8_block_matmul,
|
||||
)
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
cutlass_scaled_mm,
|
||||
per_token_group_quant_fp8,
|
||||
w8a8_triton_block_scaled_mm,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.deep_gemm import (
|
||||
fp8_gemm_nt,
|
||||
get_col_major_tma_aligned_tensor,
|
||||
per_block_cast_to_fp8,
|
||||
should_use_deepgemm_for_fp8_linear,
|
||||
)
|
||||
from vllm.utils.import_utils import has_deep_gemm
|
||||
|
||||
if current_platform.get_device_capability() < (9, 0):
|
||||
pytest.skip("FP8 Triton requires CUDA 9.0 or higher", allow_module_level=True)
|
||||
|
||||
vllm_config = VllmConfig()
|
||||
|
||||
# Test configurations
|
||||
DTYPES = [torch.bfloat16] # [torch.half, torch.bfloat16, torch.float32]
|
||||
NUM_TOKENS = [7, 2050]
|
||||
D = [512, 4096, 5120, 13824]
|
||||
GROUP_SIZE = [64, 128, 512]
|
||||
M = [1, 7, 8, 83, 84, 4096]
|
||||
N = [128, 512, 7168, 7748, 13824]
|
||||
K = [256, 3884, 4096, 13824, 16384]
|
||||
# Deepseek-V3's intermediate size 18432, so N is 18432*2/8=4608 at TP8
|
||||
# and its hidden size is 7168.
|
||||
BLOCK_SIZE = [[128, 128]]
|
||||
OUT_DTYPES = [torch.bfloat16] # [torch.float32, torch.half, torch.bfloat16]
|
||||
SEEDS = [0]
|
||||
|
||||
# Skip all tests if CUDA is not available
|
||||
pytest.importorskip("torch.cuda")
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_cuda():
|
||||
torch.set_default_device("cuda")
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
current_platform.is_fp8_fnuz(),
|
||||
reason="This platform supports e4m3fnuz, not e4m3fn.",
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"num_tokens,d,dtype,group_size,seed",
|
||||
itertools.product(NUM_TOKENS, D, DTYPES, GROUP_SIZE, SEEDS),
|
||||
)
|
||||
@torch.inference_mode()
|
||||
def test_per_token_group_quant_fp8(num_tokens, d, dtype, group_size, seed):
|
||||
torch.manual_seed(seed)
|
||||
x = torch.rand(num_tokens, d, dtype=dtype)
|
||||
|
||||
ref_out, ref_scale = native_per_token_group_quant_fp8(x, group_size)
|
||||
out, scale = per_token_group_quant_fp8(x, group_size)
|
||||
|
||||
assert torch.allclose(out.to(torch.float32), ref_out.to(torch.float32), rtol=0.15)
|
||||
assert torch.allclose(scale, ref_scale)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"M,N,K,block_size,out_dtype,seed",
|
||||
itertools.product(M, N, K, BLOCK_SIZE, OUT_DTYPES, SEEDS),
|
||||
)
|
||||
@torch.inference_mode()
|
||||
def test_w8a8_block_fp8_matmul(M, N, K, block_size, out_dtype, seed):
|
||||
torch.manual_seed(seed)
|
||||
factor_for_scale = 1e-2
|
||||
fp8_info = torch.finfo(current_platform.fp8_dtype())
|
||||
fp8_max, fp8_min = fp8_info.max, fp8_info.min
|
||||
|
||||
A_fp32 = (torch.rand(M, K, dtype=torch.float32) - 0.5) * 2 * fp8_max
|
||||
A_fp8 = A_fp32.clamp(min=fp8_min, max=fp8_max).to(current_platform.fp8_dtype())
|
||||
|
||||
B_fp32 = (torch.rand(N, K, dtype=torch.float32) - 0.5) * 2 * fp8_max
|
||||
B_fp8 = B_fp32.clamp(min=fp8_min, max=fp8_max).to(current_platform.fp8_dtype())
|
||||
|
||||
block_n, block_k = block_size[0], block_size[1]
|
||||
n_tiles = (N + block_n - 1) // block_n
|
||||
k_tiles = (K + block_k - 1) // block_k
|
||||
|
||||
As = torch.rand(M, k_tiles, dtype=torch.float32) * factor_for_scale
|
||||
Bs = torch.rand(n_tiles, k_tiles, dtype=torch.float32) * factor_for_scale
|
||||
|
||||
ref_out = native_w8a8_block_matmul(A_fp8, B_fp8, As, Bs, block_size, out_dtype)
|
||||
out = w8a8_triton_block_scaled_mm(A_fp8, B_fp8, As, Bs, block_size, out_dtype)
|
||||
|
||||
rel_diff = torch.mean(
|
||||
torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))
|
||||
) / torch.mean(torch.abs(ref_out.to(torch.float32)))
|
||||
assert rel_diff < 0.001
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not current_platform.is_cuda(), reason="CUTLASS only supported on CUDA platform."
|
||||
)
|
||||
@torch.inference_mode()
|
||||
def test_w8a8_block_fp8_cutlass_matmul():
|
||||
# Test simple case where weight.shape % 128 != 0,
|
||||
# like in DSV3 kv_a_proj_with_mqa
|
||||
M = 32
|
||||
N = 576
|
||||
K = 7168
|
||||
block_size = [128, 128]
|
||||
out_dtype = torch.bfloat16
|
||||
seed = 0
|
||||
|
||||
torch.manual_seed(seed)
|
||||
factor_for_scale = 1e-2
|
||||
fp8_info = torch.finfo(torch.float8_e4m3fn)
|
||||
fp8_max, fp8_min = fp8_info.max, fp8_info.min
|
||||
|
||||
A_fp32 = (torch.rand(M, K, dtype=torch.float32) - 0.5) * 2 * fp8_max
|
||||
|
||||
B_fp32 = (torch.rand(N, K, dtype=torch.float32) - 0.5) * 2 * fp8_max
|
||||
B_fp8 = B_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
|
||||
|
||||
block_n, block_k = block_size[0], block_size[1]
|
||||
n_tiles = (N + block_n - 1) // block_n
|
||||
k_tiles = (K + block_k - 1) // block_k
|
||||
|
||||
Bs = torch.rand(n_tiles, k_tiles, dtype=torch.float32) * factor_for_scale
|
||||
# Hopper requires row-major format for scales
|
||||
Bs_cutlass = Bs.T.contiguous() if current_platform.is_device_capability(90) else Bs
|
||||
|
||||
A_fp8, As = per_token_group_quant_fp8(
|
||||
A_fp32, block_size[1], column_major_scales=False
|
||||
)
|
||||
# CUTLASS uses column-major format for scales
|
||||
A_fp8_cutlass, As_cutlass = per_token_group_quant_fp8(
|
||||
A_fp32, block_size[1], column_major_scales=True
|
||||
)
|
||||
|
||||
ref_out = native_w8a8_block_matmul(A_fp8, B_fp8, As, Bs, block_size, out_dtype)
|
||||
out = cutlass_scaled_mm(
|
||||
A_fp8_cutlass, B_fp8, As_cutlass, Bs_cutlass, block_size, out_dtype
|
||||
)
|
||||
|
||||
rel_diff = torch.mean(
|
||||
torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))
|
||||
) / torch.mean(torch.abs(ref_out.to(torch.float32)))
|
||||
assert rel_diff < 0.001
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
current_platform.is_fp8_fnuz(),
|
||||
reason="This platform supports e4m3fnuz, not e4m3fn.",
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"M,N,K,block_size,out_dtype,seed",
|
||||
itertools.product(M, N, K, BLOCK_SIZE, OUT_DTYPES, SEEDS),
|
||||
)
|
||||
@pytest.mark.skipif(not has_deep_gemm(), reason="DeepGemm kernels not available.")
|
||||
@torch.inference_mode()
|
||||
def test_w8a8_block_fp8_deep_gemm_matmul(M, N, K, block_size, out_dtype, seed):
|
||||
torch.manual_seed(seed)
|
||||
fp8_info = torch.finfo(torch.float8_e4m3fn)
|
||||
fp8_max = fp8_info.max
|
||||
|
||||
A_fp32 = (torch.rand(M, K, dtype=torch.float32) - 0.5) * 2 * fp8_max
|
||||
B_fp32 = (torch.rand(N, K, dtype=torch.float32) - 0.5) * 2 * fp8_max
|
||||
|
||||
# only aligned sizes are supported by deepgemm
|
||||
if not should_use_deepgemm_for_fp8_linear(
|
||||
output_dtype=out_dtype, weight=B_fp32, supports_deep_gemm=True
|
||||
):
|
||||
pytest.skip(f"Skipping test; invalid size {M}, {N}, {K}")
|
||||
|
||||
A_fp8, As_fp8 = per_token_group_quant_fp8(A_fp32, block_size[1])
|
||||
B_fp8, Bs_fp8 = per_block_cast_to_fp8(B_fp32, block_size=block_size)
|
||||
|
||||
As = As_fp8.to(torch.float32)
|
||||
Bs = Bs_fp8.to(torch.float32)
|
||||
|
||||
ref_out = native_w8a8_block_matmul(A_fp8, B_fp8, As, Bs, block_size, out_dtype)
|
||||
|
||||
# Transpose earlier so that the testing will not trigger transposing kernels
|
||||
As_fp8 = get_col_major_tma_aligned_tensor(As_fp8)
|
||||
|
||||
out = torch.zeros((M, N), device="cuda", dtype=out_dtype)
|
||||
|
||||
assert As_fp8.shape == (M, (K + 127) // 128), (
|
||||
f"{As_fp8.shape} != {(M, (K + 127) // 128)}"
|
||||
)
|
||||
|
||||
fp8_gemm_nt((A_fp8, As_fp8), (B_fp8, Bs_fp8), out)
|
||||
|
||||
rel_diff = torch.mean(
|
||||
torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))
|
||||
) / torch.mean(torch.abs(ref_out.to(torch.float32)))
|
||||
assert rel_diff < 0.001
|
||||
67
tests/kernels/quantization/test_block_int8.py
Normal file
67
tests/kernels/quantization/test_block_int8.py
Normal file
@@ -0,0 +1,67 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
# Adapted from https://github.com/sgl-project/sglang/blob/main/test/srt/test_block_int8.py
|
||||
import itertools
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from tests.kernels.quant_utils import native_w8a8_block_matmul
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.model_executor.layers.quantization.utils.int8_utils import (
|
||||
w8a8_block_int8_matmul,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
if current_platform.get_device_capability() < (7, 0):
|
||||
pytest.skip("INT8 Triton requires CUDA 7.0 or higher", allow_module_level=True)
|
||||
|
||||
vllm_config = VllmConfig()
|
||||
|
||||
DTYPES = [torch.half, torch.bfloat16]
|
||||
M = [1, 33, 64, 222]
|
||||
N = [128, 1024]
|
||||
K = [256, 4096]
|
||||
# BLOCK_SIZE = [[64, 64], [64, 128], [128, 64], [128, 128]]
|
||||
BLOCK_SIZE = [[128, 128]]
|
||||
SEEDS = [0]
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True, scope="module")
|
||||
def setup_cuda():
|
||||
"""Sets the default CUDA device for all tests in this module."""
|
||||
torch.set_default_device("cuda")
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"M,N,K,block_size,out_dtype,seed",
|
||||
itertools.product(M, N, K, BLOCK_SIZE, DTYPES, SEEDS),
|
||||
)
|
||||
@torch.inference_mode()
|
||||
def test_w8a8_block_int8_matmul(M, N, K, block_size, out_dtype, seed):
|
||||
torch.manual_seed(seed)
|
||||
factor_for_scale = 1e-2
|
||||
int8_info = torch.iinfo(torch.int8)
|
||||
int8_max, int8_min = int8_info.max, int8_info.min
|
||||
|
||||
A_fp32 = (torch.rand(M, K, dtype=torch.float32) - 0.5) * 2 * int8_max
|
||||
A_fp8 = A_fp32.clamp(min=int8_min, max=int8_max).to(torch.float8_e4m3fn)
|
||||
|
||||
B_fp32 = (torch.rand(N, K, dtype=torch.float32) - 0.5) * 2 * int8_max
|
||||
B_fp8 = B_fp32.clamp(min=int8_min, max=int8_max).to(torch.float8_e4m3fn)
|
||||
|
||||
block_n, block_k = block_size[0], block_size[1]
|
||||
n_tiles = (N + block_n - 1) // block_n
|
||||
k_tiles = (K + block_k - 1) // block_k
|
||||
|
||||
As = torch.rand(M, k_tiles, dtype=torch.float32) * factor_for_scale
|
||||
Bs = torch.rand(n_tiles, k_tiles, dtype=torch.float32) * factor_for_scale
|
||||
|
||||
ref_out = native_w8a8_block_matmul(A_fp8, B_fp8, As, Bs, block_size, out_dtype)
|
||||
out = w8a8_block_int8_matmul(A_fp8, B_fp8, As, Bs, block_size, out_dtype)
|
||||
|
||||
rel_diff = torch.mean(
|
||||
torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))
|
||||
) / torch.mean(torch.abs(ref_out.to(torch.float32)))
|
||||
assert rel_diff < 0.001
|
||||
236
tests/kernels/quantization/test_cutlass_2of4_sparse.py
Normal file
236
tests/kernels/quantization/test_cutlass_2of4_sparse.py
Normal file
@@ -0,0 +1,236 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Tests for sparse cutlass kernels
|
||||
|
||||
Run `pytest tests/kernels/quantization/test_cutlass_2of4_sparse.py`.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from tests.kernels.utils import baseline_scaled_mm, to_fp8, to_int8
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
sparse_cutlass_supported,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
CUDA_DEVICES = [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)]
|
||||
|
||||
capability = current_platform.get_device_capability()
|
||||
capability = capability[0] * 10 + capability[1]
|
||||
|
||||
|
||||
def to_bf16(tensor: torch.Tensor) -> torch.Tensor:
|
||||
return tensor.to(dtype=torch.bfloat16)
|
||||
|
||||
|
||||
def to_fp16(tensor: torch.Tensor) -> torch.Tensor:
|
||||
return tensor.to(dtype=torch.float16)
|
||||
|
||||
|
||||
def prune_to_2_4(tensor):
|
||||
# Reshape tensor to [N, 4] where N is number of groups of 4
|
||||
original_shape = tensor.shape
|
||||
reshaped = tensor.reshape(-1, 4)
|
||||
|
||||
# Get indices of top 2 absolute values in each group of 4
|
||||
_, indices = torch.topk(torch.abs(reshaped), k=2, dim=1)
|
||||
|
||||
# Create binary mask
|
||||
mask = torch.zeros_like(reshaped)
|
||||
mask.scatter_(dim=1, index=indices, src=torch.ones_like(indices, dtype=mask.dtype))
|
||||
|
||||
# Apply mask and reshape back
|
||||
pruned = reshaped * mask
|
||||
|
||||
# Turn all -0.0 to 0.0
|
||||
pruned[pruned == -0.0] = 0.0
|
||||
|
||||
return pruned.reshape(original_shape)
|
||||
|
||||
|
||||
# This function checks that applying an identity matrix multiplication
|
||||
# to the compressed weights yields the original uncompressed weights.
|
||||
def check_compress_decompress_invariance(
|
||||
dtype: torch.dtype,
|
||||
b: torch.Tensor,
|
||||
b_compressed: torch.Tensor,
|
||||
b_metadata: torch.Tensor,
|
||||
):
|
||||
# For float16 and bfloat16, cutlass_scaled_sparse_mm's output must be the
|
||||
# same dtype as its inputs. This line addresses that constraint while
|
||||
# arbitrarily using bfloat16 for the int8/fp8 cases.
|
||||
out_dtype = torch.float16 if dtype is torch.float16 else torch.bfloat16
|
||||
|
||||
eye = torch.eye(b.shape[0], device="cuda", dtype=dtype)
|
||||
eye_scale = torch.ones(1, device="cuda", dtype=torch.float32)
|
||||
b_decomp = ops.cutlass_scaled_sparse_mm(
|
||||
eye, b_compressed, b_metadata, eye_scale, eye_scale, out_dtype=out_dtype
|
||||
)
|
||||
|
||||
torch.testing.assert_close(b.to(dtype=out_dtype), b_decomp)
|
||||
|
||||
|
||||
def make_rand_sparse_tensors(
|
||||
dtype: torch.dtype, m: int, n: int, k: int
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
a = torch.randn((m, k), device="cuda")
|
||||
b = torch.randn((n, k), device="cuda").t()
|
||||
|
||||
if dtype == torch.int8:
|
||||
# ensure A and B aren't all zeros after rounding
|
||||
a = a * 5.0
|
||||
b = b * 5.0
|
||||
|
||||
b = prune_to_2_4(b.t()).t()
|
||||
|
||||
if dtype == torch.int8:
|
||||
a, b = to_int8(a), to_int8(b)
|
||||
elif dtype == torch.float8_e4m3fn:
|
||||
a, b = to_fp8(a), to_fp8(b)
|
||||
elif dtype == torch.float16:
|
||||
a, b = to_fp16(a), to_fp16(b)
|
||||
elif dtype == torch.bfloat16:
|
||||
a, b = to_bf16(a), to_bf16(b)
|
||||
else:
|
||||
raise ValueError("unsupported dtype")
|
||||
|
||||
b_compressed, e = ops.cutlass_sparse_compress(b.t())
|
||||
check_compress_decompress_invariance(dtype, b, b_compressed, e)
|
||||
|
||||
# Compressed B, Metadata, Original A, B
|
||||
return b_compressed, e, a, b
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not sparse_cutlass_supported(),
|
||||
reason="Sparse CUTLASS is not supported on this GPU type.",
|
||||
)
|
||||
# Test working with a subset of A and B for sparse matmul
|
||||
def test_cutlass_sparse_subset():
|
||||
big_m = 1024
|
||||
m, n, k = 512, 512, 512
|
||||
|
||||
# Create tensors
|
||||
b_comp, e, whole_a, b = make_rand_sparse_tensors(torch.float8_e4m3fn, big_m, n, k)
|
||||
a = whole_a[0:m, 0:k]
|
||||
scale_a = torch.randn((1, 1), device="cuda", dtype=torch.float32) / 10
|
||||
scale_b = torch.randn((1, 1), device="cuda", dtype=torch.float32) / 10
|
||||
|
||||
out = ops.cutlass_scaled_sparse_mm(
|
||||
a, b_comp, e, scale_a, scale_b, out_dtype=torch.bfloat16
|
||||
)
|
||||
baseline = baseline_scaled_mm(a, b, scale_a, scale_b, out_dtype=torch.bfloat16)
|
||||
|
||||
torch.testing.assert_close(out, baseline, rtol=1e-1, atol=1e0)
|
||||
|
||||
|
||||
MNK_FACTORS = [
|
||||
(1, 256, 128),
|
||||
(1, 16384, 1024),
|
||||
(1, 24576, 512),
|
||||
(16, 256, 512),
|
||||
(16, 16384, 128),
|
||||
(16, 24576, 4096),
|
||||
(32, 8192, 4096),
|
||||
(32, 16384, 4096),
|
||||
(33, 1024, 1024),
|
||||
(33, 8192, 128),
|
||||
(64, 2048, 512),
|
||||
(64, 16384, 1024),
|
||||
(100, 8192, 512),
|
||||
(128, 32768, 4096),
|
||||
(256, 4096, 4096),
|
||||
(512, 256, 1024),
|
||||
(512, 8192, 4096),
|
||||
(512, 16384, 128),
|
||||
(512, 24576, 128),
|
||||
]
|
||||
|
||||
|
||||
# Test working with a subset of A and B for sparse matmul
|
||||
@pytest.mark.skipif(
|
||||
not sparse_cutlass_supported(),
|
||||
reason="Sparse CUTLASS is not supported on this GPU type.",
|
||||
)
|
||||
@pytest.mark.parametrize("m, n, k", MNK_FACTORS)
|
||||
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
|
||||
@pytest.mark.parametrize("use_bias", [True, False])
|
||||
def test_cutlass_sparse_gemm(
|
||||
m: int, k: int, n: int, dtype: type[torch.dtype], use_bias: bool
|
||||
):
|
||||
# Create tensors
|
||||
b_comp, e, a, b = make_rand_sparse_tensors(dtype, m, n, k)
|
||||
scale_a = torch.ones((1, 1), device="cuda", dtype=torch.float32)
|
||||
scale_b = torch.ones((1, 1), device="cuda", dtype=torch.float32)
|
||||
|
||||
bias = torch.rand((n,), device="cuda", dtype=dtype) if use_bias else None
|
||||
|
||||
out = ops.cutlass_scaled_sparse_mm(
|
||||
a, b_comp, e, scale_a, scale_b, out_dtype=dtype, bias=bias
|
||||
)
|
||||
|
||||
baseline = baseline_scaled_mm(a, b, scale_a, scale_b, out_dtype=dtype, bias=bias)
|
||||
|
||||
torch.testing.assert_close(out, baseline, rtol=1e-2, atol=3e-1)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not sparse_cutlass_supported(),
|
||||
reason="Sparse CUTLASS is not supported on this GPU type.",
|
||||
)
|
||||
@pytest.mark.parametrize("m, k, n", MNK_FACTORS)
|
||||
@pytest.mark.skipif(
|
||||
not current_platform.has_device_capability(89),
|
||||
reason="FP8 is not supported on this GPU type.",
|
||||
)
|
||||
@pytest.mark.parametrize("use_bias", [True, False])
|
||||
def test_cutlass_sparse_fp8_gemm(m: int, n: int, k: int, use_bias: bool):
|
||||
# Create tensors
|
||||
b_comp, e, a, b = make_rand_sparse_tensors(torch.float8_e4m3fn, m, n, k)
|
||||
scale_a = torch.randn((1, 1), device="cuda", dtype=torch.float32)
|
||||
scale_b = torch.randn((1, 1), device="cuda", dtype=torch.float32)
|
||||
out_dtype = torch.bfloat16
|
||||
|
||||
bias = torch.rand((n,), device="cuda", dtype=out_dtype) * 10 if use_bias else None
|
||||
|
||||
out = ops.cutlass_scaled_sparse_mm(
|
||||
a, b_comp, e, scale_a, scale_b, out_dtype=out_dtype, bias=bias
|
||||
)
|
||||
|
||||
baseline = baseline_scaled_mm(
|
||||
a, b, scale_a, scale_b, out_dtype=out_dtype, bias=bias
|
||||
)
|
||||
|
||||
torch.testing.assert_close(out, baseline, rtol=1e-2, atol=3e-1)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not sparse_cutlass_supported(),
|
||||
reason="Sparse CUTLASS is not supported on this GPU type.",
|
||||
)
|
||||
@pytest.mark.parametrize("m,k,n", MNK_FACTORS)
|
||||
@pytest.mark.parametrize("per_act_token", [True, False])
|
||||
@pytest.mark.parametrize("per_out_ch", [True, False])
|
||||
@pytest.mark.parametrize("use_bias", [True, False])
|
||||
def test_cutlass_sparse_int8_gemm(
|
||||
m: int, n: int, k: int, per_act_token: bool, per_out_ch: bool, use_bias: bool
|
||||
):
|
||||
# Create tensors
|
||||
b_comp, e, a, b = make_rand_sparse_tensors(torch.int8, m, n, k)
|
||||
scale_a = torch.randn((1, 1), device="cuda", dtype=torch.float32)
|
||||
scale_b = torch.randn((1, 1), device="cuda", dtype=torch.float32)
|
||||
out_dtype = torch.bfloat16
|
||||
|
||||
bias = torch.rand((n,), device="cuda", dtype=out_dtype) * 10 if use_bias else None
|
||||
|
||||
out = ops.cutlass_scaled_sparse_mm(
|
||||
a, b_comp, e, scale_a, scale_b, out_dtype=out_dtype, bias=bias
|
||||
)
|
||||
|
||||
baseline = baseline_scaled_mm(
|
||||
a, b, scale_a, scale_b, out_dtype=out_dtype, bias=bias
|
||||
)
|
||||
|
||||
torch.testing.assert_close(out, baseline, rtol=1e0, atol=2e0)
|
||||
682
tests/kernels/quantization/test_cutlass_scaled_mm.py
Normal file
682
tests/kernels/quantization/test_cutlass_scaled_mm.py
Normal file
@@ -0,0 +1,682 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Tests for cutlass kernels
|
||||
|
||||
Run `pytest tests/kernels/quantization/test_cutlass_scaled_mm.py`.
|
||||
"""
|
||||
|
||||
import random
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from tests.kernels.utils import baseline_scaled_mm, opcheck, to_fp8, to_int8
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.math_utils import cdiv
|
||||
|
||||
if not current_platform.is_cuda():
|
||||
pytest.skip("These tests use CUTLASS which requires CUDA", allow_module_level=True)
|
||||
|
||||
MNK_FACTORS = [
|
||||
(1, 256, 128),
|
||||
(1, 16384, 1024),
|
||||
(1, 24576, 496),
|
||||
(16, 256, 496),
|
||||
(16, 16384, 128),
|
||||
(16, 24576, 4096),
|
||||
(32, 8192, 4096),
|
||||
(32, 16384, 4096),
|
||||
(33, 1024, 1024),
|
||||
(33, 8192, 128),
|
||||
(64, 2048, 496),
|
||||
(64, 16384, 1024),
|
||||
(100, 8192, 496),
|
||||
(128, 32768, 4096),
|
||||
(256, 4096, 4096),
|
||||
(512, 256, 1024),
|
||||
(512, 8192, 4096),
|
||||
(512, 16384, 128),
|
||||
(512, 24576, 128),
|
||||
]
|
||||
|
||||
CUDA_DEVICES = [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)]
|
||||
|
||||
# -1 means full extent in that dimension
|
||||
TENSORWISE_GROUP_SHAPE = (-1, -1)
|
||||
PER_TOKEN_GROUP_SHAPE = (1, -1)
|
||||
PER_OUT_CH_GROUP_SHAPE = (-1, 1)
|
||||
|
||||
capability = current_platform.get_device_capability()
|
||||
capability = capability[0] * 10 + capability[1]
|
||||
|
||||
|
||||
def rand_int8(shape: tuple, device: str = "cuda"):
|
||||
return to_int8(torch.rand(shape, device=device) * 255 - 128)
|
||||
|
||||
|
||||
def group_scale_helper(shape, group_shape):
|
||||
return [shape[i] if s < 0 else s for i, s in enumerate(group_shape)]
|
||||
|
||||
|
||||
def scale_shape(shape, group_shape):
|
||||
assert len(shape) == len(group_shape)
|
||||
group_shape = group_scale_helper(shape, group_shape)
|
||||
return tuple(cdiv(shape[i], group_shape[i]) for i in range(len(group_shape)))
|
||||
|
||||
|
||||
def cutlass_fp8_gemm_helper(
|
||||
m: int,
|
||||
n: int,
|
||||
k: int,
|
||||
a_scale_group_shape: tuple,
|
||||
b_scale_group_shape: tuple,
|
||||
use_bias: bool,
|
||||
out_dtype: type[torch.dtype] = torch.bfloat16,
|
||||
device: str = "cuda",
|
||||
):
|
||||
# Test for a cutlass kernel with per-token activation quantization
|
||||
# and per-output channel weight quantization.
|
||||
a = to_fp8(torch.randn((m, k), device=device))
|
||||
b = to_fp8(torch.randn((n, k), device=device).t())
|
||||
|
||||
a_scales_shape = scale_shape(a.shape, a_scale_group_shape)
|
||||
b_scales_shape = scale_shape(b.shape, b_scale_group_shape)
|
||||
|
||||
scale_a = torch.randn(a_scales_shape, device=device, dtype=torch.float32)
|
||||
scale_b = torch.randn(b_scales_shape, device=device, dtype=torch.float32)
|
||||
|
||||
# make scales M-major for blockwise quant, doesn't affect 1D scales
|
||||
scale_a = scale_a.t().contiguous().t()
|
||||
# make scales K-major for blockwise quant, doesn't affect 1D scales
|
||||
scale_b = scale_b.t().contiguous().t()
|
||||
|
||||
bias = torch.rand((n,), device=device, dtype=out_dtype) * 10 if use_bias else None
|
||||
|
||||
out = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
|
||||
baseline = baseline_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
|
||||
|
||||
torch.testing.assert_close(out, baseline, rtol=5e-1, atol=1.5e-1)
|
||||
|
||||
opcheck(torch.ops._C.cutlass_scaled_mm, (out, a, b, scale_a, scale_b, bias))
|
||||
|
||||
|
||||
def cutlass_int8_gemm_helper(
|
||||
m: int,
|
||||
n: int,
|
||||
k: int,
|
||||
a_scale_group_shape: tuple,
|
||||
b_scale_group_shape: tuple,
|
||||
use_bias: bool,
|
||||
out_dtype: type[torch.dtype] = torch.bfloat16,
|
||||
device: str = "cuda",
|
||||
):
|
||||
# Test for a cutlass kernel with per-token activation quantization
|
||||
# and per-output channel weight quantization.
|
||||
a = to_int8(torch.randn((m, k), device=device) * 5)
|
||||
b = to_int8(torch.randn((n, k), device=device).t() * 5)
|
||||
|
||||
a_scales_shape = scale_shape(a.shape, a_scale_group_shape)
|
||||
b_scales_shape = scale_shape(b.shape, b_scale_group_shape)
|
||||
|
||||
scale_a = torch.randn(a_scales_shape, device=device, dtype=torch.float32)
|
||||
scale_b = torch.randn(b_scales_shape, device=device, dtype=torch.float32)
|
||||
|
||||
bias = torch.rand((n,), device=device, dtype=out_dtype) * 10 if use_bias else None
|
||||
|
||||
out = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
|
||||
baseline = baseline_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
|
||||
|
||||
torch.testing.assert_close(out, baseline, rtol=1e-1, atol=1e0)
|
||||
|
||||
opcheck(torch.ops._C.cutlass_scaled_mm, (out, a, b, scale_a, scale_b, bias))
|
||||
|
||||
|
||||
@pytest.mark.parametrize("m,n,k", MNK_FACTORS)
|
||||
@pytest.mark.parametrize(
|
||||
"a_scale_group_shape", [PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE]
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"b_scale_group_shape", [PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE]
|
||||
)
|
||||
@pytest.mark.parametrize("use_bias", [True, False])
|
||||
@pytest.mark.skipif(
|
||||
not current_platform.has_device_capability(89),
|
||||
reason="FP8 is not supported on this GPU type.",
|
||||
)
|
||||
def test_cutlass_fp8_gemm(
|
||||
m: int, n: int, k: int, a_scale_group_shape, b_scale_group_shape, use_bias: bool
|
||||
):
|
||||
cutlass_fp8_gemm_helper(m, n, k, a_scale_group_shape, b_scale_group_shape, use_bias)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("m,n,k", MNK_FACTORS)
|
||||
@pytest.mark.parametrize(
|
||||
"a_scale_group_shape,b_scale_group_shape", [((1, 128), (128, 128))]
|
||||
)
|
||||
@pytest.mark.parametrize("use_bias", [False])
|
||||
@pytest.mark.skipif(
|
||||
not current_platform.has_device_capability(90),
|
||||
reason="FP8 blockwise is not supported on this GPU type.",
|
||||
)
|
||||
def test_cutlass_fp8_blockwise_scale_gemm(
|
||||
m: int, n: int, k: int, a_scale_group_shape, b_scale_group_shape, use_bias: bool
|
||||
):
|
||||
if k % b_scale_group_shape[0] != 0 or n % b_scale_group_shape[1] != 0:
|
||||
return
|
||||
if m % a_scale_group_shape[0] != 0 or k % a_scale_group_shape[1] != 0:
|
||||
return
|
||||
if m % 4 != 0 and current_platform.has_device_capability(100):
|
||||
return
|
||||
cutlass_fp8_gemm_helper(m, n, k, a_scale_group_shape, b_scale_group_shape, use_bias)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("m,n,k", MNK_FACTORS)
|
||||
@pytest.mark.parametrize(
|
||||
"a_scale_group_shape", [PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE]
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"b_scale_group_shape", [PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE]
|
||||
)
|
||||
@pytest.mark.parametrize("use_bias", [True, False])
|
||||
def test_cutlass_int8_gemm(
|
||||
m: int, n: int, k: int, a_scale_group_shape, b_scale_group_shape, use_bias: bool
|
||||
):
|
||||
cutlass_int8_gemm_helper(
|
||||
m, n, k, a_scale_group_shape, b_scale_group_shape, use_bias
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"a_scale_group_shape", [PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE]
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"b_scale_group_shape", [PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE]
|
||||
)
|
||||
@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16])
|
||||
@pytest.mark.parametrize("use_bias", [True, False])
|
||||
def test_cutlass_int8_gemm_output_dtype(
|
||||
a_scale_group_shape,
|
||||
b_scale_group_shape,
|
||||
out_dtype: type[torch.dtype],
|
||||
use_bias: bool,
|
||||
):
|
||||
cutlass_int8_gemm_helper(
|
||||
512,
|
||||
512,
|
||||
512,
|
||||
a_scale_group_shape,
|
||||
b_scale_group_shape,
|
||||
use_bias,
|
||||
out_dtype=out_dtype,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"a_scale_group_shape", [PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE]
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"b_scale_group_shape", [PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE]
|
||||
)
|
||||
@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16])
|
||||
@pytest.mark.parametrize("use_bias", [True, False])
|
||||
@pytest.mark.skipif(
|
||||
not current_platform.has_device_capability(89),
|
||||
reason="FP8 is not supported on this GPU type.",
|
||||
)
|
||||
def test_cutlass_fp8_gemm_output_dtype(
|
||||
a_scale_group_shape,
|
||||
b_scale_group_shape,
|
||||
out_dtype: type[torch.dtype],
|
||||
use_bias: bool,
|
||||
):
|
||||
cutlass_fp8_gemm_helper(
|
||||
512,
|
||||
512,
|
||||
512,
|
||||
a_scale_group_shape,
|
||||
b_scale_group_shape,
|
||||
use_bias,
|
||||
out_dtype=out_dtype,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"a_scale_group_shape,b_scale_group_shape", [((1, 128), (128, 128))]
|
||||
)
|
||||
@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16])
|
||||
@pytest.mark.parametrize("use_bias", [False])
|
||||
@pytest.mark.skipif(
|
||||
not current_platform.has_device_capability(90),
|
||||
reason="FP8 blockwise is not supported on this GPU type.",
|
||||
)
|
||||
def test_cutlass_fp8_blockwise_scale_gemm_dtype(
|
||||
a_scale_group_shape,
|
||||
b_scale_group_shape,
|
||||
out_dtype: type[torch.dtype],
|
||||
use_bias: bool,
|
||||
):
|
||||
cutlass_fp8_gemm_helper(
|
||||
512,
|
||||
512,
|
||||
512,
|
||||
a_scale_group_shape,
|
||||
b_scale_group_shape,
|
||||
use_bias,
|
||||
out_dtype=out_dtype,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"a_scale_group_shape", [PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE]
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"b_scale_group_shape", [PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE]
|
||||
)
|
||||
@pytest.mark.parametrize("use_bias", [True, False])
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@pytest.mark.skipif(
|
||||
not current_platform.has_device_capability(89),
|
||||
reason="FP8 is not supported on this GPU type.",
|
||||
)
|
||||
def test_cutlass_fp8_gemm_devices(
|
||||
a_scale_group_shape, b_scale_group_shape, use_bias: bool, device: str
|
||||
):
|
||||
cutlass_fp8_gemm_helper(
|
||||
512,
|
||||
512,
|
||||
512,
|
||||
a_scale_group_shape,
|
||||
b_scale_group_shape,
|
||||
use_bias,
|
||||
torch.bfloat16,
|
||||
device,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"a_scale_group_shape", [PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE]
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"b_scale_group_shape", [PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE]
|
||||
)
|
||||
@pytest.mark.parametrize("use_bias", [True, False])
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
def test_cutlass_int8_gemm_devices(
|
||||
a_scale_group_shape, b_scale_group_shape, use_bias: bool, device: str
|
||||
):
|
||||
cutlass_int8_gemm_helper(
|
||||
512,
|
||||
512,
|
||||
512,
|
||||
a_scale_group_shape,
|
||||
b_scale_group_shape,
|
||||
use_bias,
|
||||
out_dtype=torch.bfloat16,
|
||||
device=device,
|
||||
)
|
||||
|
||||
|
||||
# For the following two tests:
|
||||
# N and K correspond to the size of the weight matrix and likely to be multiples
|
||||
# of a large power of two. In any case, the kernel will have a naive fallback
|
||||
# when N and K are not divisible by 16. But M is the number of tokens and the
|
||||
# kernel must handle any M thrown at it.
|
||||
@pytest.mark.parametrize(
|
||||
"a_scale_group_shape", [PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE]
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"b_scale_group_shape", [PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE]
|
||||
)
|
||||
@pytest.mark.parametrize("use_bias", [True, False])
|
||||
@pytest.mark.skipif(
|
||||
not current_platform.has_device_capability(89),
|
||||
reason="FP8 is not supported on this GPU type.",
|
||||
)
|
||||
def test_cutlass_fp8_gemm_m_sweep(
|
||||
a_scale_group_shape, b_scale_group_shape, use_bias: bool
|
||||
):
|
||||
for nk in range(32, 128, 32):
|
||||
for m in range(1, 128):
|
||||
cutlass_fp8_gemm_helper(
|
||||
m, nk, nk, a_scale_group_shape, b_scale_group_shape, use_bias
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"a_scale_group_shape", [PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE]
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"b_scale_group_shape", [PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE]
|
||||
)
|
||||
@pytest.mark.parametrize("use_bias", [True, False])
|
||||
def test_cutlass_int8_gemm_m_sweep(
|
||||
a_scale_group_shape, b_scale_group_shape, use_bias: bool
|
||||
):
|
||||
for nk in range(32, 128, 32):
|
||||
for m in range(1, 128):
|
||||
cutlass_int8_gemm_helper(
|
||||
m, nk, nk, a_scale_group_shape, b_scale_group_shape, use_bias
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("m", [32, 64, 128])
|
||||
@pytest.mark.parametrize("n", [16, 32, 64])
|
||||
@pytest.mark.parametrize("k", [64, 128, 256])
|
||||
@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16])
|
||||
@pytest.mark.skip
|
||||
def test_cutlass_int8_azp_bias_fold(m: int, n: int, k: int, out_dtype: torch.dtype):
|
||||
# Currently, the test is failing because folding azp into
|
||||
# 16-bit bias loses too much precision
|
||||
scale_a = torch.randn((1, 1), device="cuda", dtype=torch.float32) / 10
|
||||
scale_b = torch.randn((1, n), device="cuda", dtype=torch.float32) / 10
|
||||
|
||||
aq_i8 = rand_int8((m, k))
|
||||
bq_i8 = rand_int8((n, k)).t()
|
||||
|
||||
aq_i32 = aq_i8.to(dtype=torch.int32)
|
||||
bq_i32 = bq_i8.to(dtype=torch.int32)
|
||||
|
||||
aq_f32 = aq_i8.to(dtype=torch.float32)
|
||||
bq_f32 = bq_i8.to(dtype=torch.float32)
|
||||
|
||||
b_dq = scale_b * bq_f32
|
||||
|
||||
azp_a = torch.rand((1,), device="cuda", dtype=torch.float32) * 10 + 1.5
|
||||
azp_aq_i8 = (azp_a / scale_a).to(dtype=torch.int8)
|
||||
azp_a = azp_aq_i8.to(dtype=torch.float32) * scale_a # correct for rounding
|
||||
|
||||
a_dq = scale_a * (aq_i32 + azp_aq_i8).to(dtype=torch.float32)
|
||||
torch.testing.assert_close(a_dq, scale_a * aq_f32 + azp_a)
|
||||
|
||||
baseline_dq = torch.mm(a_dq, b_dq).to(out_dtype)
|
||||
|
||||
J = torch.ones((1, k), device="cuda", dtype=torch.float32)
|
||||
azp_bias = (azp_a * scale_b * (J @ bq_f32)).to(out_dtype)
|
||||
assert azp_bias.shape == (1, n)
|
||||
assert azp_bias[0, :].shape == (n,)
|
||||
|
||||
baseline_q = (
|
||||
scale_a.to(device="cpu")
|
||||
* scale_b.to(device="cpu")
|
||||
* ((aq_i32 + azp_aq_i8).to(device="cpu") @ bq_i32.to(device="cpu"))
|
||||
).to(dtype=out_dtype, device="cuda")
|
||||
|
||||
out = ops.cutlass_scaled_mm(
|
||||
aq_i8, bq_i8, scale_a, scale_b, out_dtype=out_dtype, bias=azp_bias[0, :]
|
||||
)
|
||||
torch.testing.assert_close(out, baseline_dq, rtol=1e-2, atol=1e0)
|
||||
torch.testing.assert_close(out, baseline_q, rtol=1e-2, atol=1e0)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("m", [32, 64, 128])
|
||||
@pytest.mark.parametrize("n", [16, 32, 64])
|
||||
@pytest.mark.parametrize("k", [64, 128, 256])
|
||||
@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16])
|
||||
@pytest.mark.parametrize("use_bias", [True, False])
|
||||
@pytest.mark.parametrize("azp_per_token", [True, False])
|
||||
def test_cutlass_int8_azp(
|
||||
m: int, n: int, k: int, out_dtype: torch.dtype, use_bias: bool, azp_per_token: bool
|
||||
):
|
||||
m_azp = m if azp_per_token else 1
|
||||
scale_a = torch.randn((m_azp, 1), device="cuda", dtype=torch.float32) / 10
|
||||
scale_b = torch.randn((1, n), device="cuda", dtype=torch.float32) / 10
|
||||
|
||||
aq_i8 = rand_int8((m, k))
|
||||
aq_i32 = aq_i8.to(dtype=torch.int32)
|
||||
aq_f32 = aq_i8.to(dtype=torch.float32)
|
||||
|
||||
bq_i8 = rand_int8((n, k)).t()
|
||||
bq_i32 = bq_i8.to(dtype=torch.int32)
|
||||
bq_f32 = bq_i8.to(dtype=torch.float32)
|
||||
b_dq = scale_b * bq_f32
|
||||
|
||||
azp_a = torch.rand((m_azp, 1), device="cuda", dtype=torch.float32) * 10 + 1.5
|
||||
azp_aq_i8 = (azp_a / scale_a).to(dtype=torch.int8)
|
||||
azp_a = azp_aq_i8.to(dtype=torch.float32) * scale_a # correct for rounding
|
||||
|
||||
a_dq = scale_a * (aq_i32 - azp_aq_i8).to(dtype=torch.float32)
|
||||
torch.testing.assert_close(a_dq, scale_a * aq_f32 - azp_a, rtol=1e-4, atol=1e-3)
|
||||
|
||||
if use_bias:
|
||||
bias = torch.rand((1, n), device="cuda", dtype=out_dtype) * 10 + 2.5
|
||||
else:
|
||||
bias = torch.zeros((1, n), device="cuda", dtype=out_dtype)
|
||||
|
||||
baseline_dq = (torch.mm(a_dq, b_dq) + bias).to(out_dtype)
|
||||
|
||||
# int32 mm not supported on CUDA
|
||||
a_noazp_i32_cpu = (aq_i32 - azp_aq_i8).to(device="cpu")
|
||||
cq = (a_noazp_i32_cpu @ bq_i32.to(device="cpu")).to(device="cuda")
|
||||
baseline_q = (scale_a * scale_b * cq + bias).to(dtype=out_dtype)
|
||||
|
||||
# Hadamard is just the sum of the cols
|
||||
azp_adj_i32 = bq_i32.sum(dim=0, keepdim=True, dtype=torch.int32)
|
||||
azp_i32 = azp_aq_i8.to(dtype=torch.int32)
|
||||
func_bias = bias if use_bias else None
|
||||
|
||||
if azp_per_token:
|
||||
out = ops.cutlass_scaled_mm_azp(
|
||||
aq_i8, bq_i8, scale_a, scale_b, out_dtype, azp_adj_i32, azp_i32, func_bias
|
||||
)
|
||||
else:
|
||||
azp_with_adj_i32 = azp_i32 * azp_adj_i32
|
||||
out = ops.cutlass_scaled_mm_azp(
|
||||
aq_i8, bq_i8, scale_a, scale_b, out_dtype, azp_with_adj_i32, None, func_bias
|
||||
)
|
||||
|
||||
# bfloat16 precision is 7-bit mantissa -> 2^-8 ~ 0.4%
|
||||
# float16 precision is 10-bit mantissa -> 2^-11 ~ 0.05%
|
||||
rtol = 1e-2 if out_dtype == torch.bfloat16 else 1e-3
|
||||
atol = 1e-3
|
||||
torch.testing.assert_close(out, baseline_dq, rtol=rtol, atol=atol)
|
||||
torch.testing.assert_close(out, baseline_q, rtol=rtol, atol=atol)
|
||||
|
||||
if azp_per_token:
|
||||
opcheck(
|
||||
torch.ops._C.cutlass_scaled_mm_azp,
|
||||
(out, aq_i8, bq_i8, scale_a, scale_b, azp_adj_i32, azp_i32, func_bias),
|
||||
)
|
||||
else:
|
||||
opcheck(
|
||||
torch.ops._C.cutlass_scaled_mm_azp,
|
||||
(out, aq_i8, bq_i8, scale_a, scale_b, azp_with_adj_i32, None, func_bias),
|
||||
)
|
||||
|
||||
|
||||
# Test working with a subset of A and B
|
||||
def test_cutlass_subset():
|
||||
big_m, big_n, big_k = 1024, 1024, 1024
|
||||
m, n, k = 512, 512, 512
|
||||
|
||||
whole_a = to_int8(torch.randn((big_m, big_k), device="cuda") * 5)
|
||||
whole_b = to_int8(torch.randn((big_n, big_k), device="cuda").t() * 5)
|
||||
a = whole_a[0:m, 0:k]
|
||||
b = whole_b[0:k, 0:n]
|
||||
|
||||
scale_a = torch.randn((1, 1), device="cuda", dtype=torch.float32) / 10
|
||||
scale_b = torch.randn((1, 1), device="cuda", dtype=torch.float32) / 10
|
||||
|
||||
out = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype=torch.bfloat16)
|
||||
baseline = baseline_scaled_mm(a, b, scale_a, scale_b, out_dtype=torch.bfloat16)
|
||||
|
||||
torch.testing.assert_close(out, baseline, rtol=1e-1, atol=1e0)
|
||||
|
||||
|
||||
# Test to make sure cuda graphs work
|
||||
class CutlassLayer(torch.nn.Module):
|
||||
def __init__(self, b, scale_a, scale_b, out_dtype):
|
||||
super().__init__()
|
||||
self.b = b
|
||||
self.scale_a = scale_a
|
||||
self.scale_b = scale_b
|
||||
self.out_dtype = out_dtype
|
||||
|
||||
def forward(self, a):
|
||||
return ops.cutlass_scaled_mm(
|
||||
a, self.b, self.scale_a, self.scale_b, self.out_dtype
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("per_act_token", [True, False])
|
||||
@pytest.mark.parametrize("per_out_ch", [True, False])
|
||||
def test_cutlass_cuda_graph(per_act_token: bool, per_out_ch: bool):
|
||||
m, n, k = 512, 512, 512
|
||||
|
||||
a = to_int8(torch.randn((m, k), device="cuda"))
|
||||
b = to_int8(torch.randn((n, k), device="cuda").t())
|
||||
|
||||
m_a_scales = m if per_act_token else 1
|
||||
n_b_scales = n if per_out_ch else 1
|
||||
|
||||
scale_a = torch.randn((m_a_scales, 1), device="cuda", dtype=torch.float32) / 10
|
||||
scale_b = torch.randn((1, n_b_scales), device="cuda", dtype=torch.float32) / 10
|
||||
|
||||
# Construct a trivial model with a single layer that calls a CUTLASS kernel
|
||||
model = CutlassLayer(b, scale_a, scale_b, torch.bfloat16)
|
||||
|
||||
# Run the model with a cuda graph
|
||||
stream = torch.cuda.Stream()
|
||||
with torch.cuda.stream(stream):
|
||||
g = torch.cuda.CUDAGraph()
|
||||
with torch.cuda.graph(g):
|
||||
out = model(a)
|
||||
out.zero_()
|
||||
g.replay()
|
||||
|
||||
baseline = torch.mm(
|
||||
scale_a * a.to(dtype=torch.float32), scale_b * b.to(dtype=torch.float32)
|
||||
).to(torch.bfloat16)
|
||||
torch.testing.assert_close(out, baseline, rtol=1e-1, atol=1e0)
|
||||
|
||||
|
||||
def test_cutlass_support_opcheck():
|
||||
opcheck(torch.ops._C.cutlass_scaled_mm_supports_fp8, (capability,))
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_experts", [8, 64])
|
||||
@pytest.mark.parametrize("per_act_token", [True, False])
|
||||
@pytest.mark.parametrize("per_out_ch", [True, False])
|
||||
@pytest.mark.parametrize("use_bias", [False])
|
||||
@pytest.mark.skipif(
|
||||
(lambda x: x is None or not ops.cutlass_group_gemm_supported(x.to_int()))(
|
||||
current_platform.get_device_capability()
|
||||
),
|
||||
reason="Grouped gemm is not supported on this GPU type.",
|
||||
)
|
||||
def test_cutlass_fp8_group_gemm(
|
||||
num_experts: int, per_act_token: bool, per_out_ch: bool, use_bias: bool
|
||||
):
|
||||
# Device and dtype setup
|
||||
device = "cuda"
|
||||
out_dtype = torch.half
|
||||
|
||||
# Create separate A, B, C tensors for each group
|
||||
a_tensors = []
|
||||
b_tensors = []
|
||||
a_scales_tensors = []
|
||||
b_scales_tensors = []
|
||||
baseline_tensors = []
|
||||
|
||||
expert_offsets = torch.zeros((num_experts + 1), device=device, dtype=torch.int64)
|
||||
|
||||
problem_sizes = torch.zeros((num_experts, 3), device=device, dtype=torch.int32)
|
||||
|
||||
if not per_act_token:
|
||||
one_scale_a = torch.randn((1, 1), device=device, dtype=torch.float32)
|
||||
|
||||
alignment = 16 # 128 // 8
|
||||
# For variation, each group has dimensions
|
||||
n_g = alignment * random.randint(1, 64)
|
||||
k_g = alignment * random.randint(1, 64)
|
||||
for g in range(num_experts):
|
||||
m_g = alignment * random.randint(1, 64)
|
||||
|
||||
expert_offsets[g + 1] = expert_offsets[g] + m_g
|
||||
problem_sizes[g][0] = m_g
|
||||
problem_sizes[g][1] = n_g
|
||||
problem_sizes[g][2] = k_g
|
||||
|
||||
m_a_scales = m_g if per_act_token else 1
|
||||
n_b_scales = n_g if per_out_ch else 1
|
||||
|
||||
# Create group-specific A and B (FP8) and output (FP16/FP32)
|
||||
a_g = to_fp8(torch.randn((m_g, k_g), device=device))
|
||||
b_g = to_fp8(torch.randn((n_g, k_g), device=device).t())
|
||||
a_tensors.append(a_g)
|
||||
b_tensors.append(b_g)
|
||||
|
||||
# Set up A/B scales
|
||||
scale_b = torch.randn((1, n_b_scales), device=device, dtype=torch.float32)
|
||||
b_scales_tensors.append(scale_b)
|
||||
|
||||
if per_act_token:
|
||||
scale_a = torch.randn((m_a_scales, 1), device=device, dtype=torch.float32)
|
||||
a_scales_tensors.append(scale_a)
|
||||
else:
|
||||
scale_a = one_scale_a
|
||||
|
||||
# Compute baseline result for this group
|
||||
baseline_g = baseline_scaled_mm(a_g, b_g, scale_a, scale_b, out_dtype, None)
|
||||
baseline_tensors.append(baseline_g)
|
||||
|
||||
a_tensors_stacked = torch.empty(
|
||||
(expert_offsets[num_experts], k_g), device=device, dtype=torch.float8_e4m3fn
|
||||
)
|
||||
b_tensors_stacked = torch.empty(
|
||||
(num_experts, n_g, k_g), device=device, dtype=torch.float8_e4m3fn
|
||||
)
|
||||
|
||||
for g in range(num_experts):
|
||||
a_tensors_stacked[expert_offsets[g] : expert_offsets[g + 1]] = a_tensors[g]
|
||||
b_tensors_stacked[g] = b_tensors[g].t()
|
||||
b_tensors_stacked = b_tensors_stacked.transpose(1, 2)
|
||||
|
||||
if per_act_token:
|
||||
a_scales_tensors_stacked = torch.empty(
|
||||
(expert_offsets[num_experts], 1), device=device, dtype=torch.float32
|
||||
)
|
||||
for g in range(num_experts):
|
||||
a_scales_tensors_stacked[expert_offsets[g] : expert_offsets[g + 1]] = (
|
||||
a_scales_tensors[g]
|
||||
)
|
||||
else:
|
||||
a_scales_tensors_stacked = one_scale_a
|
||||
|
||||
b_scales_tensors_stacked = torch.empty(
|
||||
(num_experts, n_b_scales), device=device, dtype=torch.float32
|
||||
)
|
||||
for g in range(num_experts):
|
||||
b_scales_tensors_stacked[g] = b_scales_tensors[g]
|
||||
|
||||
out_tensors_stacked = torch.zeros(
|
||||
(expert_offsets[num_experts], n_g), device=device, dtype=out_dtype
|
||||
)
|
||||
|
||||
ab_strides = torch.full(
|
||||
(num_experts,), a_tensors_stacked.stride(0), device="cuda", dtype=torch.int64
|
||||
)
|
||||
c_strides = torch.full(
|
||||
(num_experts,), out_tensors_stacked.stride(0), device="cuda", dtype=torch.int64
|
||||
)
|
||||
|
||||
ops.cutlass_moe_mm(
|
||||
out_tensors_stacked,
|
||||
a_tensors_stacked,
|
||||
b_tensors_stacked,
|
||||
a_scales_tensors_stacked,
|
||||
b_scales_tensors_stacked,
|
||||
expert_offsets[:-1],
|
||||
problem_sizes,
|
||||
ab_strides,
|
||||
ab_strides,
|
||||
c_strides,
|
||||
per_act_token,
|
||||
per_out_ch,
|
||||
)
|
||||
|
||||
# Validate each group's result against the baseline
|
||||
for g in range(num_experts):
|
||||
baseline = baseline_tensors[g]
|
||||
c = out_tensors_stacked[expert_offsets[g] : expert_offsets[g + 1]]
|
||||
torch.testing.assert_close(c, baseline, rtol=1e-2, atol=5e-4)
|
||||
329
tests/kernels/quantization/test_cutlass_w4a8.py
Normal file
329
tests/kernels/quantization/test_cutlass_w4a8.py
Normal file
@@ -0,0 +1,329 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Tests for the CUTLASS W4A8 kernel.
|
||||
|
||||
Run `pytest tests/kernels/quantization/test_cutlass_w4a8.py`.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
convert_packed_uint4b8_to_signed_int4_inplace,
|
||||
pack_cols,
|
||||
pack_rows,
|
||||
quantize_weights,
|
||||
unpack_quantized_values_into_int32,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.scalar_type import ScalarType, scalar_types
|
||||
|
||||
if not current_platform.is_cuda():
|
||||
pytest.skip("These tests use CUTLASS which requires CUDA", allow_module_level=True)
|
||||
|
||||
# TODO: in future PR refactor this and `is_quant_method_supported` in the kernel
|
||||
# unit tests to a common utility function. Currently the use of
|
||||
# `is_quant_method_supported` conflates kernels with quantization methods
|
||||
# an assumption which is breaking down as quantizations methods can have
|
||||
# have kernels and some kernels support multiple quantization methods.
|
||||
IS_SUPPORTED_BY_GPU = current_platform.get_device_capability()[0] >= 9
|
||||
|
||||
MNK_SHAPES = [
|
||||
(1, 128, 128),
|
||||
(1, 512, 1024),
|
||||
(1, 4096, 4096),
|
||||
(1, 8192, 28672),
|
||||
(13, 8192, 4096),
|
||||
(26, 4096, 8192),
|
||||
(64, 4096, 4096),
|
||||
(64, 8192, 28672),
|
||||
(257, 128, 4096),
|
||||
(257, 4096, 4096),
|
||||
(1024, 4096, 8192),
|
||||
(1024, 8192, 4096),
|
||||
]
|
||||
|
||||
# TODO(czhu): get supported schedules from fn
|
||||
SCHEDULES = [
|
||||
"128x16_1x1x1",
|
||||
"256x16_1x1x1",
|
||||
"128x32_1x1x1",
|
||||
"256x32_1x1x1",
|
||||
"128x64_1x1x1",
|
||||
"256x64_1x1x1",
|
||||
"128x128_1x1x1",
|
||||
"256x128_1x1x1",
|
||||
"128x256_1x1x1",
|
||||
"128x256_2x1x1",
|
||||
]
|
||||
|
||||
|
||||
@dataclass
|
||||
class TypeConfig:
|
||||
act_type: torch.dtype
|
||||
weight_type: ScalarType
|
||||
output_type: torch.dtype | None
|
||||
group_scale_type: torch.dtype | None
|
||||
channel_scale_type: torch.dtype | None
|
||||
token_scale_type: torch.dtype | None
|
||||
|
||||
|
||||
@dataclass
|
||||
class Tensors:
|
||||
w_ref: torch.Tensor
|
||||
a_ref: torch.Tensor
|
||||
a: torch.Tensor
|
||||
w_q: torch.Tensor
|
||||
w_g_s: torch.Tensor
|
||||
w_ch_s: torch.Tensor
|
||||
w_tok_s: torch.Tensor
|
||||
|
||||
|
||||
# (Act Type, Weight Type, Output Type, Scale Type, ZeroPoints,
|
||||
# Ch Scales Type, Tok Scales Type)
|
||||
TestTypeTuple = tuple[
|
||||
list[torch.dtype], ScalarType, torch.dtype | None, torch.dtype | None, bool
|
||||
]
|
||||
TEST_TYPES = [
|
||||
*(
|
||||
TypeConfig(
|
||||
act_type=torch.float8_e4m3fn,
|
||||
weight_type=w_type,
|
||||
output_type=o_type,
|
||||
group_scale_type=torch.float8_e4m3fn,
|
||||
channel_scale_type=torch.float32,
|
||||
token_scale_type=torch.float32,
|
||||
)
|
||||
for w_type in [scalar_types.int4]
|
||||
# TODO(czhu): fp16 out type
|
||||
for o_type in [torch.bfloat16]
|
||||
),
|
||||
]
|
||||
|
||||
# TODO: in future PR refactor this and `is_quant_method_supported` in the kernel
|
||||
# unit tests to a common utility function. Currently the use of
|
||||
# `is_quant_method_supported` conflates kernels with quantization methods
|
||||
# an assumption which is breaking down as quantizations methods can have
|
||||
# have kernels and some kernels support multiple quantization methods.
|
||||
IS_SUPPORTED_BY_GPU = current_platform.has_device_capability(90)
|
||||
|
||||
|
||||
# For testing quantized linear kernels
|
||||
def to_fp8(tensor: torch.Tensor):
|
||||
finfo = torch.finfo(torch.float8_e4m3fn)
|
||||
return tensor.clamp(min=finfo.min, max=finfo.max).to(dtype=torch.float8_e4m3fn)
|
||||
|
||||
|
||||
def cutlass_quantize_and_pack(
|
||||
atype: torch.dtype,
|
||||
w: torch.Tensor,
|
||||
wtype: ScalarType,
|
||||
stype: torch.dtype | None,
|
||||
group_size: int | None,
|
||||
zero_points: bool = False,
|
||||
):
|
||||
assert wtype.is_integer(), "TODO: support floating point weights"
|
||||
|
||||
w_ref, w_q, w_s, w_zp = quantize_weights(
|
||||
w, wtype, group_size=group_size, zero_points=zero_points
|
||||
)
|
||||
|
||||
# since scales are cast to fp8, we need to compute w_ref this way
|
||||
w_ref = (
|
||||
(w_q).to(torch.float32)
|
||||
* w_s.to(atype).to(torch.float32).repeat_interleave(group_size, dim=0)
|
||||
).to(atype)
|
||||
|
||||
# bit mask prevents sign extending int4 when packing
|
||||
w_q = pack_rows(w_q & 0x0F, wtype.size_bits, *w_q.shape)
|
||||
w_q = w_q.t().contiguous().t() # convert to col major
|
||||
|
||||
w_q_packed = ops.cutlass_encode_and_reorder_int4b(w_q)
|
||||
w_s_packed = ops.cutlass_pack_scale_fp8(w_s.to(atype))
|
||||
|
||||
return w_ref, w_q_packed, w_s_packed, w_zp
|
||||
|
||||
|
||||
def create_test_tensors(
|
||||
shape: tuple[int, int, int], types: TypeConfig, group_size: int | None
|
||||
) -> Tensors:
|
||||
m, n, k = shape
|
||||
|
||||
print(
|
||||
"create_test_tensors, shape:", shape, "types:", types, "group_size:", group_size
|
||||
)
|
||||
|
||||
a = to_fp8(torch.randn((m, k), device="cuda"))
|
||||
w = to_fp8(torch.randn((k, n), device="cuda"))
|
||||
|
||||
if types.group_scale_type is not None:
|
||||
w = w.to(types.group_scale_type)
|
||||
if w.dtype.itemsize == 1:
|
||||
w = w.to(torch.float16)
|
||||
|
||||
w_ref, w_q_packed, w_s, _ = cutlass_quantize_and_pack(
|
||||
a.dtype, w, types.weight_type, types.group_scale_type, group_size, False
|
||||
)
|
||||
|
||||
a_ref = a.to(torch.float32)
|
||||
w_ref = w_ref.to(torch.float32)
|
||||
|
||||
# for the practical use case we need per-tok scales for fp8 activations
|
||||
w_tok_s = torch.randn((m,), device="cuda", dtype=types.token_scale_type)
|
||||
w_ch_s = torch.randn((n,), device="cuda", dtype=types.channel_scale_type)
|
||||
|
||||
return Tensors(
|
||||
w_ref=w_ref,
|
||||
a_ref=a_ref,
|
||||
a=a,
|
||||
w_q=w_q_packed,
|
||||
w_g_s=w_s,
|
||||
w_ch_s=w_ch_s,
|
||||
w_tok_s=w_tok_s,
|
||||
)
|
||||
|
||||
|
||||
def mm_test_helper(
|
||||
types: TypeConfig,
|
||||
tensors: Tensors,
|
||||
group_size: int | None = None,
|
||||
schedule: str | None = None,
|
||||
):
|
||||
# CUTLASS upstream uses fp8 with fastaccum as reference
|
||||
# https://github.com/NVIDIA/cutlass/blob/main/examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_fp8_gemm.cu#L406
|
||||
output_ref = torch._scaled_mm(
|
||||
tensors.a_ref.to(types.act_type),
|
||||
tensors.w_ref.to(types.act_type).t().contiguous().t(), # col major
|
||||
tensors.w_tok_s.unsqueeze(1),
|
||||
tensors.w_ch_s.unsqueeze(0),
|
||||
out_dtype=types.output_type,
|
||||
use_fast_accum=True,
|
||||
)
|
||||
|
||||
output = ops.cutlass_w4a8_mm(
|
||||
a=tensors.a,
|
||||
b_q=tensors.w_q,
|
||||
b_group_scales=tensors.w_g_s,
|
||||
b_group_size=group_size,
|
||||
b_channel_scales=tensors.w_ch_s,
|
||||
a_token_scales=tensors.w_tok_s,
|
||||
)
|
||||
|
||||
print(output)
|
||||
print(output_ref)
|
||||
|
||||
torch.testing.assert_close(
|
||||
output, output_ref.to(output.dtype), rtol=1e-2, atol=1e-2
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not IS_SUPPORTED_BY_GPU, reason="CUTLASS W4A8 is not supported on this GPU type."
|
||||
)
|
||||
@pytest.mark.parametrize("shape", MNK_SHAPES, ids=lambda x: "x".join(str(v) for v in x))
|
||||
@pytest.mark.parametrize("types", TEST_TYPES)
|
||||
@pytest.mark.parametrize("schedule", SCHEDULES)
|
||||
def test_cutlass_w4a8(shape, types: TypeConfig, schedule):
|
||||
group_sizes = [128]
|
||||
for group_size in group_sizes:
|
||||
tensors = create_test_tensors(shape, types, group_size)
|
||||
mm_test_helper(types, tensors, group_size, schedule)
|
||||
|
||||
|
||||
# Test to make sure cuda graphs work
|
||||
class W4A8Layer(torch.nn.Module):
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__()
|
||||
self.kwargs = kwargs
|
||||
|
||||
def forward(self, a):
|
||||
return ops.cutlass_w4a8_mm(a=a, **self.kwargs)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not IS_SUPPORTED_BY_GPU, reason="CUTLASS W4A8 is not supported on this GPU type."
|
||||
)
|
||||
def test_w4a8_cuda_graph():
|
||||
m, n, k = 512, 4096, 4096
|
||||
|
||||
a = to_fp8(torch.randn((m, k), device="cuda"))
|
||||
b = to_fp8(torch.randn((k, n), device="cuda"))
|
||||
|
||||
wtype = scalar_types.int4
|
||||
stype = torch.float8_e4m3fn
|
||||
group_size = 128
|
||||
zero_points = False
|
||||
|
||||
w_ref, w_q_packed, w_s, _ = cutlass_quantize_and_pack(
|
||||
a.dtype, b.to(torch.float16), wtype, stype, group_size, zero_points
|
||||
)
|
||||
|
||||
w_tok_s = torch.randn((m,), device="cuda", dtype=torch.float32)
|
||||
w_ch_s = torch.randn((n,), device="cuda", dtype=torch.float32)
|
||||
|
||||
# Construct a trivial model with a single layer that calls the kernel
|
||||
model = W4A8Layer(
|
||||
b_q=w_q_packed,
|
||||
b_group_scales=w_s,
|
||||
b_group_size=group_size,
|
||||
b_channel_scales=w_ch_s,
|
||||
a_token_scales=w_tok_s,
|
||||
)
|
||||
|
||||
output_ref = torch._scaled_mm(
|
||||
a,
|
||||
w_ref.to(a.dtype).t().contiguous().t(), # col major
|
||||
w_tok_s.unsqueeze(1),
|
||||
w_ch_s.unsqueeze(0),
|
||||
out_dtype=torch.bfloat16,
|
||||
use_fast_accum=True,
|
||||
)
|
||||
|
||||
# Run the model with a cuda graph
|
||||
stream = torch.cuda.Stream()
|
||||
with torch.cuda.stream(stream):
|
||||
g = torch.cuda.CUDAGraph()
|
||||
with torch.cuda.graph(g):
|
||||
output = model(a)
|
||||
|
||||
output.zero_()
|
||||
g.replay()
|
||||
|
||||
torch.testing.assert_close(output, output_ref, rtol=1e-2, atol=1e-2)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not IS_SUPPORTED_BY_GPU, reason="CUTLASS W4A8 is not supported on this GPU type."
|
||||
)
|
||||
@pytest.mark.parametrize("shape", MNK_SHAPES)
|
||||
def test_convert_packed_uint4b8_to_signed_int4_inplace(shape):
|
||||
"""
|
||||
The W4A16 checkpoints encode the weights as int4b8 packed to int32.
|
||||
The CUTLASS kernels expect signed int4 packed to int32.
|
||||
This tests checks that the runtime int4b8 -> signed int4 conversion
|
||||
matches the offline conversion step exactly.
|
||||
"""
|
||||
_, N, K = shape
|
||||
# random weights packed to int32
|
||||
t = torch.randint(
|
||||
low=torch.iinfo(torch.int32).min,
|
||||
high=torch.iinfo(torch.int32).max + 1,
|
||||
size=(N, K // 8),
|
||||
dtype=torch.int32,
|
||||
device="cuda",
|
||||
)
|
||||
|
||||
# compute reference
|
||||
unpacked = unpack_quantized_values_into_int32(
|
||||
t.clone(), scalar_types.uint4b8, packed_dim=1
|
||||
)
|
||||
unpacked = unpacked - 8 # int4b8 -> signed int4
|
||||
ref = pack_cols(unpacked & 0x0F, 4, *unpacked.shape)
|
||||
|
||||
out = convert_packed_uint4b8_to_signed_int4_inplace(t.clone())
|
||||
|
||||
assert torch.equal(ref, out)
|
||||
assert not torch.equal(ref, t)
|
||||
342
tests/kernels/quantization/test_cutlass_w4a8_moe.py
Normal file
342
tests/kernels/quantization/test_cutlass_w4a8_moe.py
Normal file
@@ -0,0 +1,342 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
Tests for the CUTLASS-based W4A8 grouped GEMM kernel and the full MoE layer.
|
||||
"""
|
||||
|
||||
import random
|
||||
from dataclasses import dataclass
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
pack_rows,
|
||||
quantize_weights,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.scalar_type import ScalarType, scalar_types
|
||||
|
||||
IS_SUPPORTED_BY_GPU = (
|
||||
current_platform.is_cuda() and current_platform.get_device_capability()[0] >= 9
|
||||
)
|
||||
|
||||
|
||||
def to_fp8(tensor: torch.Tensor) -> torch.Tensor:
|
||||
finfo = torch.finfo(torch.float8_e4m3fn)
|
||||
return tensor.clamp(min=finfo.min, max=finfo.max).to(dtype=torch.float8_e4m3fn)
|
||||
|
||||
|
||||
def cutlass_quantize(
|
||||
atype: torch.dtype,
|
||||
w: torch.Tensor,
|
||||
wtype: ScalarType,
|
||||
stype: torch.dtype | None,
|
||||
group_size: int | None,
|
||||
zero_points: bool = False,
|
||||
):
|
||||
"""
|
||||
Quantize weights into W4 and compute reference dequantized weights.
|
||||
|
||||
Encoding/reordering of weights and packing of scales is deferred
|
||||
until after all experts are combined.
|
||||
"""
|
||||
assert wtype.is_integer(), "TODO: support floating point weights"
|
||||
|
||||
w_ref, w_q, w_s, w_zp = quantize_weights(
|
||||
w, wtype, group_size=group_size, zero_points=zero_points
|
||||
)
|
||||
|
||||
# Since scales are later cast to fp8, recompute w_ref in atype here.
|
||||
w_ref = (
|
||||
w_q.to(torch.float32)
|
||||
* w_s.to(atype).to(torch.float32).repeat_interleave(group_size, dim=0)
|
||||
).to(atype)
|
||||
|
||||
# Bit mask prevents sign extension of int4 when packing.
|
||||
w_q = pack_rows(w_q & 0x0F, wtype.size_bits, *w_q.shape)
|
||||
# Make weights row-major (N, K).
|
||||
w_q = w_q.t().contiguous()
|
||||
|
||||
return w_ref, w_q, w_s.to(atype), w_zp
|
||||
|
||||
|
||||
def cutlass_preprocess(
|
||||
w_q_experts: list[torch.Tensor], w_s_experts: list[torch.Tensor]
|
||||
):
|
||||
"""
|
||||
Reorder/encode expert weights and pack scales.
|
||||
|
||||
Returns:
|
||||
w_q_packed: Packed/encoded int4 weights for all experts.
|
||||
w_s_packed: Packed fp8 scales for all experts.
|
||||
packed_layout: Layout/stride metadata for grouped GEMM.
|
||||
"""
|
||||
w_s_packed = ops.cutlass_pack_scale_fp8(torch.stack(w_s_experts))
|
||||
w_q_packed, packed_layout = ops.cutlass_encode_and_reorder_int4b_grouped(
|
||||
torch.stack(w_q_experts)
|
||||
) # expects dim 3
|
||||
return w_q_packed, w_s_packed, packed_layout
|
||||
|
||||
|
||||
GROUP_SIZE = 128
|
||||
# (num_experts, N, K)
|
||||
TEST_SHAPES = [
|
||||
(8, 512, 2048),
|
||||
(8, 2048, 2048),
|
||||
(64, 512, 1024),
|
||||
(64, 2048, 2048),
|
||||
(4, 2048, 768),
|
||||
(8, 768, 2048),
|
||||
(64, 1536, 2048),
|
||||
(128, 8192, 4096), # test overflow int32
|
||||
]
|
||||
ALIGNMENT = 16 # torch._scaled_mm alignment for M, needed for reference check
|
||||
|
||||
|
||||
@dataclass
|
||||
class MoETestSetup:
|
||||
num_experts: int
|
||||
K: int
|
||||
N: int
|
||||
Ms: list[int]
|
||||
M_full: int
|
||||
a: torch.Tensor
|
||||
a_ref: torch.Tensor
|
||||
a_strides: torch.Tensor
|
||||
out: torch.Tensor
|
||||
c_strides: torch.Tensor
|
||||
per_tok_scales: torch.Tensor
|
||||
per_chan_scales: torch.Tensor
|
||||
w_refs: list[torch.Tensor]
|
||||
w_q_packed: torch.Tensor
|
||||
w_s_packed: torch.Tensor
|
||||
problem_sizes: torch.Tensor
|
||||
expert_offsets: torch.Tensor
|
||||
b_strides: torch.Tensor
|
||||
group_scale_strides: torch.Tensor
|
||||
|
||||
|
||||
def make_moe_test_setup(
|
||||
num_experts: int,
|
||||
K: int,
|
||||
N: int,
|
||||
*,
|
||||
alignment: int = ALIGNMENT,
|
||||
max_blocks: int = 64,
|
||||
device: str = "cuda",
|
||||
random_zero: bool = False,
|
||||
) -> MoETestSetup:
|
||||
"""Create a full set of tensors for testing cutlass_w4a8_moe_mm."""
|
||||
|
||||
assert K % GROUP_SIZE == 0
|
||||
# Token counts per expert (multiples of `alignment`).
|
||||
Ms = [alignment * random.randint(1, max_blocks) for _ in range(num_experts)]
|
||||
|
||||
# set random experts to 0 tokens
|
||||
if random_zero and num_experts > 1:
|
||||
num_zero = max(1, num_experts // 8)
|
||||
zero_indices = random.sample(range(num_experts), k=num_zero)
|
||||
for idx in zero_indices:
|
||||
Ms[idx] = 0
|
||||
|
||||
M_full = sum(Ms)
|
||||
assert M_full > 0
|
||||
|
||||
# Activations.
|
||||
a = to_fp8(torch.randn((M_full, K), device=device))
|
||||
a_ref = a.to(torch.float32)
|
||||
a_strides = torch.full((num_experts,), K, dtype=torch.int64, device=device)
|
||||
|
||||
# Output buffer.
|
||||
out = torch.empty((M_full, N), dtype=torch.bfloat16, device=device)
|
||||
c_strides = torch.full((num_experts,), N, dtype=torch.int64, device=device)
|
||||
|
||||
# Channel/token scales.
|
||||
per_tok_scales = torch.randn((M_full, 1), dtype=torch.float32, device=device)
|
||||
per_chan_scales = torch.randn(
|
||||
(num_experts, N, 1), dtype=torch.float32, device=device
|
||||
)
|
||||
|
||||
# Expert weights and scales.
|
||||
wtype = scalar_types.int4
|
||||
atype = stype = torch.float8_e4m3fn
|
||||
w_refs, w_qs, w_ss = [], [], []
|
||||
for _ in range(num_experts):
|
||||
b = to_fp8(torch.randn((K, N), device=device))
|
||||
w_ref, w_q, w_s, _ = cutlass_quantize(
|
||||
atype, b.to(torch.float16), wtype, stype, GROUP_SIZE, zero_points=False
|
||||
)
|
||||
w_refs.append(w_ref)
|
||||
w_qs.append(w_q)
|
||||
w_ss.append(w_s)
|
||||
|
||||
w_q_packed, w_s_packed, packed_layout = cutlass_preprocess(w_qs, w_ss)
|
||||
|
||||
problem_sizes = torch.tensor(
|
||||
[[N, M, K] for M in Ms], dtype=torch.int32, device=device
|
||||
)
|
||||
|
||||
expert_offsets = torch.cat(
|
||||
[
|
||||
torch.tensor([0], dtype=torch.int64),
|
||||
torch.cumsum(torch.tensor(Ms, dtype=torch.int64), dim=0)[:-1],
|
||||
]
|
||||
).to(device=device)
|
||||
|
||||
# B strides and group scale strides.
|
||||
b_strides = packed_layout
|
||||
group_scale_strides = torch.zeros(
|
||||
(num_experts, 2), dtype=torch.int64, device=device
|
||||
)
|
||||
group_scale_strides[:, 0] = N
|
||||
|
||||
return MoETestSetup(
|
||||
num_experts=num_experts,
|
||||
K=K,
|
||||
N=N,
|
||||
Ms=Ms,
|
||||
M_full=M_full,
|
||||
a=a,
|
||||
a_ref=a_ref,
|
||||
a_strides=a_strides,
|
||||
out=out,
|
||||
c_strides=c_strides,
|
||||
per_tok_scales=per_tok_scales,
|
||||
per_chan_scales=per_chan_scales,
|
||||
w_refs=w_refs,
|
||||
w_q_packed=w_q_packed,
|
||||
w_s_packed=w_s_packed,
|
||||
problem_sizes=problem_sizes,
|
||||
expert_offsets=expert_offsets,
|
||||
b_strides=b_strides,
|
||||
group_scale_strides=group_scale_strides,
|
||||
)
|
||||
|
||||
|
||||
def compute_moe_reference_output(setup: MoETestSetup) -> torch.Tensor:
|
||||
"""Compute reference output using torch._scaled_mm per expert."""
|
||||
out_ref = torch.empty_like(setup.out)
|
||||
|
||||
ends = torch.cumsum(torch.tensor(setup.Ms), 0).tolist()
|
||||
starts = setup.expert_offsets.cpu().tolist()
|
||||
|
||||
for i in range(setup.num_experts):
|
||||
start, end = starts[i], ends[i]
|
||||
if start == end:
|
||||
continue
|
||||
|
||||
out_ref_i = torch._scaled_mm(
|
||||
setup.a_ref[start:end].to(torch.float8_e4m3fn),
|
||||
setup.w_refs[i].to(torch.float8_e4m3fn).t().contiguous().t(),
|
||||
setup.per_tok_scales[start:end], # (M, 1)
|
||||
setup.per_chan_scales[i].reshape(1, -1), # (1, N)
|
||||
out_dtype=torch.bfloat16,
|
||||
use_fast_accum=True,
|
||||
)
|
||||
out_ref[start:end] = out_ref_i
|
||||
|
||||
return out_ref
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not IS_SUPPORTED_BY_GPU,
|
||||
reason="W4A8 Grouped GEMM is not supported on this GPU type.",
|
||||
)
|
||||
@pytest.mark.parametrize("shape", TEST_SHAPES)
|
||||
@pytest.mark.parametrize("random_zero", [True, False])
|
||||
def test_cutlass_w4a8_moe_mm_end_to_end(shape, random_zero):
|
||||
num_experts, N, K = shape
|
||||
current_platform.seed_everything(42)
|
||||
setup = make_moe_test_setup(
|
||||
num_experts=num_experts, K=K, N=N, max_blocks=64, random_zero=random_zero
|
||||
)
|
||||
|
||||
ops.cutlass_w4a8_moe_mm(
|
||||
setup.out,
|
||||
setup.a,
|
||||
setup.w_q_packed,
|
||||
setup.per_tok_scales,
|
||||
setup.per_chan_scales,
|
||||
setup.w_s_packed,
|
||||
GROUP_SIZE,
|
||||
setup.expert_offsets,
|
||||
setup.problem_sizes,
|
||||
setup.a_strides,
|
||||
setup.b_strides,
|
||||
setup.c_strides,
|
||||
setup.group_scale_strides,
|
||||
)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
out_ref = compute_moe_reference_output(setup)
|
||||
torch.testing.assert_close(setup.out, out_ref, rtol=1e-2, atol=1e-2)
|
||||
|
||||
|
||||
class W4A8MoELayer(torch.nn.Module):
|
||||
"""
|
||||
Minimal wrapper module to test cuda graphs
|
||||
"""
|
||||
|
||||
def __init__(self, setup: MoETestSetup):
|
||||
super().__init__()
|
||||
self.setup = setup
|
||||
|
||||
def forward(self, a: torch.Tensor) -> torch.Tensor:
|
||||
s = self.setup
|
||||
ops.cutlass_w4a8_moe_mm(
|
||||
s.out,
|
||||
a,
|
||||
s.w_q_packed,
|
||||
s.per_tok_scales,
|
||||
s.per_chan_scales,
|
||||
s.w_s_packed,
|
||||
GROUP_SIZE,
|
||||
s.expert_offsets,
|
||||
s.problem_sizes,
|
||||
s.a_strides,
|
||||
s.b_strides,
|
||||
s.c_strides,
|
||||
s.group_scale_strides,
|
||||
)
|
||||
return s.out
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not IS_SUPPORTED_BY_GPU,
|
||||
reason="W4A8 Grouped GEMM is not supported on this GPU type.",
|
||||
)
|
||||
def test_cutlass_w4a8_moe_mm_cuda_graph():
|
||||
current_platform.seed_everything(42)
|
||||
# Fixed config for CUDA graph test (single parameter point).
|
||||
num_experts = 8
|
||||
K = 512
|
||||
N = 2048
|
||||
|
||||
setup = make_moe_test_setup(
|
||||
num_experts=num_experts,
|
||||
K=K,
|
||||
N=N,
|
||||
max_blocks=32,
|
||||
)
|
||||
|
||||
# Construct model that calls the grouped GEMM kernel.
|
||||
model = W4A8MoELayer(setup)
|
||||
|
||||
# Build reference output once.
|
||||
out_ref = compute_moe_reference_output(setup)
|
||||
|
||||
# Capture and run the model in a CUDA graph.
|
||||
a_static = setup.a.clone() # static input tensor for graph replay
|
||||
|
||||
stream = torch.cuda.Stream()
|
||||
with torch.cuda.stream(stream):
|
||||
g = torch.cuda.CUDAGraph()
|
||||
with torch.cuda.graph(g):
|
||||
out_static = model(a_static)
|
||||
|
||||
out_static.zero_()
|
||||
g.replay()
|
||||
|
||||
torch.testing.assert_close(out_static, out_ref, rtol=1e-2, atol=1e-2)
|
||||
139
tests/kernels/quantization/test_flashinfer_nvfp4_scaled_mm.py
Normal file
139
tests/kernels/quantization/test_flashinfer_nvfp4_scaled_mm.py
Normal file
@@ -0,0 +1,139 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import pytest
|
||||
import torch
|
||||
from nvfp4_utils import (
|
||||
FLOAT4_E2M1_MAX,
|
||||
FLOAT8_E4M3_MAX,
|
||||
convert_swizzled_to_linear,
|
||||
dequantize_nvfp4_to_dtype,
|
||||
)
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.flashinfer import flashinfer_scaled_fp4_mm
|
||||
|
||||
if not current_platform.has_device_capability(100):
|
||||
pytest.skip(
|
||||
reason="Nvfp4 Requires compute capability of 10 or above.",
|
||||
allow_module_level=True,
|
||||
)
|
||||
|
||||
DTYPES = [torch.float16, torch.bfloat16]
|
||||
# m, n, k
|
||||
SHAPES = [(128, 128, 64), (128, 128, 128), (256, 128, 64), (128, 256, 128)]
|
||||
PAD_SHAPES = [(150, 128, 64), (128, 128, 96)]
|
||||
SHAPES.extend(PAD_SHAPES)
|
||||
|
||||
SEEDS = [42]
|
||||
CUDA_DEVICES = ["cuda:0"]
|
||||
|
||||
|
||||
def get_ref_results(
|
||||
a_fp4,
|
||||
b_fp4,
|
||||
a_sf,
|
||||
b_sf,
|
||||
a_global_scale,
|
||||
b_global_scale,
|
||||
m,
|
||||
n,
|
||||
dtype,
|
||||
block_size,
|
||||
device,
|
||||
):
|
||||
_, m_k = a_fp4.shape
|
||||
_, n_k = b_fp4.shape
|
||||
assert m_k == n_k
|
||||
a_in_dtype = dequantize_nvfp4_to_dtype(
|
||||
a_fp4, a_sf, a_global_scale, dtype=dtype, device=device, block_size=block_size
|
||||
)
|
||||
b_in_dtype = dequantize_nvfp4_to_dtype(
|
||||
b_fp4, b_sf, b_global_scale, dtype=dtype, device=device, block_size=block_size
|
||||
)
|
||||
return torch.matmul(a_in_dtype, b_in_dtype.t())
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("shape", SHAPES)
|
||||
@pytest.mark.parametrize("seed", SEEDS)
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@pytest.mark.parametrize("backend", ["cutlass", "trtllm"])
|
||||
@pytest.mark.parametrize("autotune", [False, True])
|
||||
@torch.inference_mode()
|
||||
def test_flashinfer_nvfp4_gemm(
|
||||
dtype: torch.dtype,
|
||||
shape: tuple[int, int, int],
|
||||
seed: int,
|
||||
device: str,
|
||||
backend: str,
|
||||
autotune: bool,
|
||||
) -> None:
|
||||
if backend == "trtllm" and dtype == torch.float16:
|
||||
pytest.skip("Only torch.bfloat16 is supported for TRTLLM FP4 GEMM operations")
|
||||
|
||||
current_platform.seed_everything(seed)
|
||||
m, n, packed_k = shape
|
||||
k = packed_k * 2
|
||||
block_size = 16
|
||||
a_dtype = torch.randn((m, k), dtype=dtype, device=device)
|
||||
b_dtype = torch.randn((n, k), dtype=dtype, device=device)
|
||||
|
||||
a_global_scale = (
|
||||
(FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.amax(a_dtype.flatten(), dim=-1)
|
||||
).to(torch.float32)
|
||||
b_global_scale = (
|
||||
(FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.amax(b_dtype.flatten(), dim=-1)
|
||||
).to(torch.float32)
|
||||
alpha = 1.0 / (a_global_scale * b_global_scale)
|
||||
# ops.scaled_fp4_quant returns swizzled scales, while weights
|
||||
# from checkpoints are in linear scales.
|
||||
# So instead of needing to swizzle for cutlass as in modelopt.py,
|
||||
# we need to unswizzle for trtllm here.
|
||||
a_fp4, a_scale_interleaved = ops.scaled_fp4_quant(a_dtype, a_global_scale)
|
||||
b_fp4, b_scale_interleaved = ops.scaled_fp4_quant(b_dtype, b_global_scale)
|
||||
|
||||
# get_ref_results unswizzles the scales internally.
|
||||
expected_out = get_ref_results(
|
||||
a_fp4,
|
||||
b_fp4,
|
||||
a_scale_interleaved,
|
||||
b_scale_interleaved,
|
||||
a_global_scale,
|
||||
b_global_scale,
|
||||
m,
|
||||
n,
|
||||
dtype,
|
||||
block_size,
|
||||
device,
|
||||
)
|
||||
|
||||
import flashinfer
|
||||
|
||||
if backend == "trtllm":
|
||||
epilogue_tile_m = 128
|
||||
b_fp4 = flashinfer.shuffle_matrix_a(b_fp4.view(torch.uint8), epilogue_tile_m)
|
||||
|
||||
b_scale_interleaved = convert_swizzled_to_linear(
|
||||
b_scale_interleaved, n, k, block_size
|
||||
)
|
||||
b_scale_interleaved = (
|
||||
flashinfer.shuffle_matrix_sf_a(
|
||||
b_scale_interleaved.view(torch.uint8), epilogue_tile_m
|
||||
)
|
||||
.reshape(b_scale_interleaved.shape)
|
||||
.view(torch.float8_e4m3fn)
|
||||
)
|
||||
|
||||
with flashinfer.autotune(autotune):
|
||||
out = flashinfer_scaled_fp4_mm(
|
||||
a_fp4,
|
||||
b_fp4,
|
||||
a_scale_interleaved,
|
||||
b_scale_interleaved,
|
||||
alpha,
|
||||
dtype,
|
||||
backend=backend,
|
||||
)
|
||||
|
||||
torch.testing.assert_close(out, expected_out.to(dtype=dtype), atol=1e-1, rtol=1e-1)
|
||||
72
tests/kernels/quantization/test_flashinfer_scaled_mm.py
Normal file
72
tests/kernels/quantization/test_flashinfer_scaled_mm.py
Normal file
@@ -0,0 +1,72 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.flashinfer import flashinfer_scaled_fp8_mm
|
||||
|
||||
if not current_platform.has_device_capability(100):
|
||||
pytest.skip(
|
||||
reason="Flashinfer FP8 gemms requires compute capability of 10.0 or above.",
|
||||
allow_module_level=True,
|
||||
)
|
||||
|
||||
DTYPES = [torch.float16, torch.bfloat16]
|
||||
# m, n, k
|
||||
SHAPES = [(128, 128, 64), (128, 128, 128), (256, 128, 64), (128, 256, 128)]
|
||||
PAD_SHAPES = [(150, 128, 64), (128, 128, 96)]
|
||||
SHAPES.extend(PAD_SHAPES)
|
||||
|
||||
SEEDS = [42]
|
||||
CUDA_DEVICES = ["cuda:0"]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("shape", SHAPES)
|
||||
@pytest.mark.parametrize("use_bias", [True, False])
|
||||
@pytest.mark.parametrize("seed", SEEDS)
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@pytest.mark.parametrize("autotune", [False, True])
|
||||
@torch.inference_mode()
|
||||
def test_flashinfer_fp8_gemm(
|
||||
dtype: torch.dtype,
|
||||
shape: tuple[int, int, int],
|
||||
use_bias: bool,
|
||||
seed: int,
|
||||
device: str,
|
||||
autotune: bool,
|
||||
) -> None:
|
||||
current_platform.seed_everything(seed)
|
||||
m, n, k = shape
|
||||
a = torch.randn((m, k), dtype=dtype, device=device)
|
||||
b = torch.randn((n, k), dtype=dtype, device=device) / k
|
||||
|
||||
a_fp8, a_scale = ops.scaled_fp8_quant(a)
|
||||
b_fp8, b_scale = ops.scaled_fp8_quant(b)
|
||||
|
||||
expected_out = torch.mm(
|
||||
a_scale * a_fp8.to(dtype=torch.float32),
|
||||
b_scale * b_fp8.to(dtype=torch.float32).t(),
|
||||
).to(dtype=dtype)
|
||||
|
||||
if use_bias:
|
||||
bias = torch.randn((n,), dtype=dtype, device=device)
|
||||
expected_out = expected_out + bias
|
||||
else:
|
||||
bias = None
|
||||
|
||||
import flashinfer
|
||||
|
||||
with flashinfer.autotune(autotune):
|
||||
out = flashinfer_scaled_fp8_mm(
|
||||
a_fp8,
|
||||
b_fp8.t(),
|
||||
a_scale,
|
||||
b_scale,
|
||||
dtype,
|
||||
bias=bias,
|
||||
)
|
||||
|
||||
torch.testing.assert_close(out, expected_out, atol=1e-2, rtol=1e-2)
|
||||
120
tests/kernels/quantization/test_fp8_quant.py
Normal file
120
tests/kernels/quantization/test_fp8_quant.py
Normal file
@@ -0,0 +1,120 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
import vllm._custom_ops as ops
|
||||
from tests.kernels.quant_utils import (
|
||||
FP8_DTYPE,
|
||||
ref_dynamic_per_tensor_fp8_quant,
|
||||
ref_dynamic_per_token_quant,
|
||||
)
|
||||
from tests.kernels.utils import opcheck
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
DTYPES = [torch.bfloat16, torch.float]
|
||||
HIDDEN_SIZES = [17, 1024, 1025, 1026, 5137, 8193]
|
||||
NUM_TOKENS = [1, 7, 4096]
|
||||
SCALE_UBS = [True, False]
|
||||
SEEDS = [0]
|
||||
|
||||
|
||||
def opcheck_fp8_quant(
|
||||
output, input, scale=None, scale_ub=None, use_per_token_if_dynamic=False
|
||||
):
|
||||
if scale is not None:
|
||||
opcheck(torch.ops._C.static_scaled_fp8_quant, (output, input, scale))
|
||||
elif use_per_token_if_dynamic:
|
||||
scale = torch.empty(
|
||||
(input.shape[0], 1), device=input.device, dtype=torch.float32
|
||||
)
|
||||
opcheck(
|
||||
torch.ops._C.dynamic_per_token_scaled_fp8_quant,
|
||||
(output, input, scale, scale_ub),
|
||||
)
|
||||
else:
|
||||
scale = torch.empty(
|
||||
(input.numel() // input.shape[-1], 1),
|
||||
device=input.device,
|
||||
dtype=torch.float32,
|
||||
)
|
||||
opcheck(torch.ops._C.dynamic_scaled_fp8_quant, (output, input, scale))
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
|
||||
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("scale_ub", SCALE_UBS)
|
||||
@pytest.mark.parametrize("seed", SEEDS)
|
||||
@torch.inference_mode()
|
||||
def test_dynamic_per_token_fp8_quant(
|
||||
num_tokens: int, hidden_size: int, dtype: torch.dtype, scale_ub: bool, seed: int
|
||||
) -> None:
|
||||
current_platform.seed_everything(seed)
|
||||
|
||||
x = (
|
||||
torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda") + 1e-6
|
||||
) # avoid nans
|
||||
|
||||
scale_ub = (
|
||||
torch.mean(x).to(dtype=torch.float32, device="cuda") if scale_ub else None
|
||||
)
|
||||
ref_out, ref_scales = ref_dynamic_per_token_quant(x, FP8_DTYPE, scale_ub)
|
||||
ops_out, ops_scales = ops.scaled_fp8_quant(
|
||||
x, scale_ub=scale_ub, use_per_token_if_dynamic=True
|
||||
)
|
||||
|
||||
torch.testing.assert_close(ref_scales, ops_scales)
|
||||
torch.testing.assert_close(
|
||||
ref_out.to(dtype=torch.float32), ops_out.to(dtype=torch.float32)
|
||||
)
|
||||
|
||||
opcheck_fp8_quant(ops_out, x, None, scale_ub, use_per_token_if_dynamic=True)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
|
||||
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("seed", SEEDS)
|
||||
@torch.inference_mode()
|
||||
def test_dynamic_per_tensor_fp8_quant(
|
||||
num_tokens: int, hidden_size: int, dtype: torch.dtype, seed: int
|
||||
) -> None:
|
||||
current_platform.seed_everything(seed)
|
||||
|
||||
x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda")
|
||||
|
||||
ref_out, ref_scale = ref_dynamic_per_tensor_fp8_quant(x)
|
||||
ops_out, ops_scale = ops.scaled_fp8_quant(x)
|
||||
|
||||
torch.testing.assert_close(ref_scale, ops_scale)
|
||||
torch.testing.assert_close(
|
||||
ref_out.to(dtype=torch.float32), ops_out.to(dtype=torch.float32)
|
||||
)
|
||||
|
||||
opcheck_fp8_quant(ops_out, x)
|
||||
|
||||
|
||||
# Regression test for a case with large activations where an int32 index cannot
|
||||
# represent the number of elements.
|
||||
@torch.inference_mode()
|
||||
@pytest.mark.parametrize("seed", SEEDS)
|
||||
def test_fp8_quant_large(seed: int) -> None:
|
||||
current_platform.seed_everything(seed)
|
||||
|
||||
num_tokens = 1024000 # Mistral-Nemo's max_position_embeddings
|
||||
hidden_size = 1152 # Smallest hidden_size to reproduce the error
|
||||
dtype = torch.bfloat16
|
||||
|
||||
x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda")
|
||||
ref_out, scale = ref_dynamic_per_tensor_fp8_quant(x)
|
||||
ops_out, _ = ops.scaled_fp8_quant(x, scale)
|
||||
|
||||
# Minimize memory footprint in this test by freeing x and upconverting
|
||||
# the outputs in place. (torch.allclose does not support fp8)
|
||||
del x
|
||||
ref_out = ref_out.to(dtype=dtype)
|
||||
ops_out = ops_out.to(dtype=dtype)
|
||||
|
||||
torch.testing.assert_close(ref_out, ops_out)
|
||||
166
tests/kernels/quantization/test_fp8_quant_group.py
Normal file
166
tests/kernels/quantization/test_fp8_quant_group.py
Normal file
@@ -0,0 +1,166 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Tests for QuantFP8 Group Quantization implementation."""
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"batch_size,hidden_dim,group_size",
|
||||
[
|
||||
(16, 256, 32), # Small
|
||||
(64, 1024, 64), # Medium
|
||||
(128, 2048, 128), # Large
|
||||
(8, 513, 64), # Non-divisible (native only)
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize("seed", [42])
|
||||
@pytest.mark.parametrize("use_ue8m0", [True, False])
|
||||
@torch.inference_mode()
|
||||
def test_quantfp8_group_functionality(
|
||||
batch_size: int, hidden_dim: int, group_size: int, seed: int, use_ue8m0: bool
|
||||
) -> None:
|
||||
"""Test QuantFP8 group quantization with various configurations.
|
||||
|
||||
Tests both CUDA and native implementations, column-major scales,
|
||||
and verifies consistency between implementations.
|
||||
"""
|
||||
current_platform.seed_everything(seed)
|
||||
|
||||
x = torch.randn((batch_size, hidden_dim), dtype=torch.bfloat16, device="cuda") * 8
|
||||
expected_num_groups = (hidden_dim + group_size - 1) // group_size
|
||||
is_divisible = hidden_dim % group_size == 0
|
||||
|
||||
group_shape = GroupShape(1, group_size)
|
||||
quant_op = QuantFP8(
|
||||
static=False,
|
||||
group_shape=group_shape,
|
||||
column_major_scales=False,
|
||||
use_ue8m0=use_ue8m0,
|
||||
)
|
||||
|
||||
# 1. Test native implementation (always available)
|
||||
x_quant_native, scales_native = quant_op.forward_native(x.clone())
|
||||
assert x_quant_native.shape == x.shape
|
||||
assert scales_native.shape == (batch_size, expected_num_groups)
|
||||
|
||||
# 2. Test column-major scales configuration
|
||||
quant_op_col = QuantFP8(
|
||||
static=False,
|
||||
group_shape=group_shape,
|
||||
column_major_scales=True,
|
||||
use_ue8m0=use_ue8m0,
|
||||
)
|
||||
_, scales_col = quant_op_col.forward_native(x.clone())
|
||||
assert scales_col.shape == (batch_size, expected_num_groups)
|
||||
assert scales_col.stride(0) == 1
|
||||
assert scales_col.stride(1) == batch_size
|
||||
|
||||
# Test column-major scales consistency
|
||||
torch.testing.assert_close(scales_col, scales_native, rtol=1e-9, atol=1e-8)
|
||||
|
||||
# 3. Test CUDA implementation (only for divisible dimensions)
|
||||
if is_divisible:
|
||||
x_quant_cuda, scales_cuda = quant_op.forward_cuda(x.clone())
|
||||
assert x_quant_cuda.shape == x.shape
|
||||
assert scales_cuda.shape == (batch_size, expected_num_groups)
|
||||
|
||||
# Verify CUDA/native consistency
|
||||
torch.testing.assert_close(scales_cuda, scales_native, rtol=2e-7, atol=2e-8)
|
||||
|
||||
# Quantized values should mostly match
|
||||
diff_count = (x_quant_cuda != x_quant_native).sum().item()
|
||||
diff_ratio = diff_count / x_quant_cuda.numel()
|
||||
assert diff_ratio < 0.002, f"Too many differences: {diff_ratio:.4%}"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("seed", [42])
|
||||
@pytest.mark.parametrize("use_ue8m0", [True, False])
|
||||
@torch.inference_mode()
|
||||
def test_quantfp8_group_multidimensional(seed: int, use_ue8m0: bool) -> None:
|
||||
current_platform.seed_everything(seed)
|
||||
|
||||
group_size = 64
|
||||
|
||||
# Test with 3D input
|
||||
batch1, batch2, hidden_dim = 4, 8, 1024
|
||||
x_3d = (
|
||||
torch.randn((batch1, batch2, hidden_dim), dtype=torch.bfloat16, device="cuda")
|
||||
* 8
|
||||
)
|
||||
|
||||
group_shape = GroupShape(1, group_size)
|
||||
quant_op = QuantFP8(
|
||||
static=False,
|
||||
group_shape=group_shape,
|
||||
column_major_scales=False,
|
||||
use_ue8m0=use_ue8m0,
|
||||
)
|
||||
|
||||
x_quant, scales = quant_op.forward_native(x_3d.clone())
|
||||
assert x_quant.shape == x_3d.shape
|
||||
assert scales.shape == (batch1, batch2, hidden_dim // group_size)
|
||||
|
||||
# Test column_major_scales with multi-dim
|
||||
quant_op_col = QuantFP8(
|
||||
static=False,
|
||||
group_shape=group_shape,
|
||||
column_major_scales=True,
|
||||
use_ue8m0=use_ue8m0,
|
||||
)
|
||||
_, scales_col = quant_op_col.forward_native(x_3d.clone())
|
||||
assert scales_col.shape == (batch1, batch2, hidden_dim // group_size)
|
||||
|
||||
# Test with 4D input
|
||||
batch1, batch2, batch3, hidden_dim = 2, 3, 4, 256
|
||||
x_4d = (
|
||||
torch.randn(
|
||||
(batch1, batch2, batch3, hidden_dim), dtype=torch.bfloat16, device="cuda"
|
||||
)
|
||||
* 8
|
||||
)
|
||||
|
||||
x_quant_4d, scales_4d = quant_op.forward_native(x_4d.clone())
|
||||
assert x_quant_4d.shape == x_4d.shape
|
||||
assert scales_4d.shape == (batch1, batch2, batch3, hidden_dim // group_size)
|
||||
|
||||
_, scales_4d_col = quant_op_col.forward_native(x_4d.clone())
|
||||
assert scales_4d_col.shape == (batch1, batch2, hidden_dim // group_size, batch3)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("seed", [42])
|
||||
@torch.inference_mode()
|
||||
def test_quantfp8_group_edge_cases(seed: int) -> None:
|
||||
current_platform.seed_everything(seed)
|
||||
|
||||
batch_size = 16
|
||||
group_size = 64
|
||||
|
||||
# Test with single group (group_size >= hidden_dim)
|
||||
x_small = torch.randn((batch_size, 32), dtype=torch.bfloat16, device="cuda") * 8
|
||||
group_shape = GroupShape(1, group_size)
|
||||
quant_op = QuantFP8(
|
||||
static=False, group_shape=group_shape, column_major_scales=False
|
||||
)
|
||||
|
||||
x_quant_small, scales_small = quant_op.forward_native(x_small.clone())
|
||||
assert x_quant_small.shape == x_small.shape
|
||||
assert scales_small.shape == (batch_size, 1)
|
||||
|
||||
# Test with zero inputs
|
||||
x_zero = torch.zeros((batch_size, 256), dtype=torch.bfloat16, device="cuda")
|
||||
x_quant_zero, scales_zero = quant_op.forward_native(x_zero.clone())
|
||||
assert x_quant_zero.shape == x_zero.shape
|
||||
assert (scales_zero > 0).all(), "Scales should be clamped to minimum"
|
||||
|
||||
# Test very large values
|
||||
x_large = torch.full((batch_size, 256), 1000.0, dtype=torch.bfloat16, device="cuda")
|
||||
x_quant_large, scales_large = quant_op.forward_native(x_large.clone())
|
||||
assert x_quant_large.shape == x_large.shape
|
||||
# FP8 max is typically 448 or 224, so scales should be > 1
|
||||
assert (scales_large > 1.0).all(), "Large values should have scales > 1"
|
||||
54
tests/kernels/quantization/test_ggml.py
Normal file
54
tests/kernels/quantization/test_ggml.py
Normal file
@@ -0,0 +1,54 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import gguf
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from tests.kernels.utils import opcheck
|
||||
from vllm import _custom_ops as ops # noqa: F401
|
||||
|
||||
|
||||
@pytest.mark.parametrize("quant_type", [12])
|
||||
def test_ggml_opcheck(quant_type):
|
||||
block_size, type_size = gguf.GGML_QUANT_SIZES[quant_type]
|
||||
shape = [256, 1152]
|
||||
qweight = torch.randint(0, 100, shape, device="cuda", dtype=torch.uint8)
|
||||
m = qweight.shape[0]
|
||||
n = qweight.shape[1] // type_size * block_size
|
||||
opcheck(torch.ops._C.ggml_dequantize, (qweight, quant_type, m, n, torch.float16))
|
||||
|
||||
x = torch.rand((m, 512), device="cuda", dtype=torch.float16)
|
||||
opcheck(torch.ops._C.ggml_mul_mat_a8, (qweight, x, quant_type, qweight.shape[0]))
|
||||
opcheck(
|
||||
torch.ops._C.ggml_mul_mat_vec_a8, (qweight, x, quant_type, qweight.shape[0])
|
||||
)
|
||||
|
||||
shape = [256, 1024, 336]
|
||||
qweight = torch.randint(0, 100, shape, device="cuda", dtype=torch.uint8)
|
||||
x = torch.rand((1, 1024), device="cuda", dtype=torch.float16)
|
||||
sorted_token_ids = torch.arange(776, device="cuda")
|
||||
expert_ids = torch.randint(0, 256, (194,), device="cuda")
|
||||
num_tokens_post_padded = torch.tensor([1], dtype=torch.int64, device="cuda")
|
||||
|
||||
opcheck(
|
||||
torch.ops._C.ggml_moe_a8,
|
||||
(
|
||||
x,
|
||||
qweight,
|
||||
sorted_token_ids,
|
||||
expert_ids,
|
||||
num_tokens_post_padded,
|
||||
quant_type,
|
||||
qweight.shape[0],
|
||||
1,
|
||||
x.shape[0],
|
||||
),
|
||||
)
|
||||
|
||||
topk_ids = torch.zeros((1, 1), device="cuda", dtype=torch.int32)
|
||||
|
||||
opcheck(
|
||||
torch.ops._C.ggml_moe_a8_vec,
|
||||
(x, qweight, topk_ids, 1, quant_type, qweight.shape[0], x.shape[0]),
|
||||
)
|
||||
207
tests/kernels/quantization/test_gguf.py
Normal file
207
tests/kernels/quantization/test_gguf.py
Normal file
@@ -0,0 +1,207 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from gguf import GGMLQuantizationType, GGUFReader, ReaderTensor, dequantize
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
import vllm._custom_ops as ops
|
||||
from vllm.model_executor.layers.fused_moe import fused_experts
|
||||
from vllm.model_executor.layers.quantization.gguf import _fused_moe_gguf
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
GGUF_SAMPLE = snapshot_download("Isotr0py/test-gguf-sample")
|
||||
GGUF_SAMPLE_MOE = snapshot_download("SzymonOzog/test-gguf-moe-sample")
|
||||
|
||||
|
||||
def get_gguf_sample_tensors(
|
||||
hidden_size: int, quant_type: GGMLQuantizationType
|
||||
) -> list[ReaderTensor]:
|
||||
sample_dir = GGUF_SAMPLE
|
||||
filename = f"Quant_{quant_type.name}_{hidden_size}.gguf"
|
||||
sample_file = Path(sample_dir) / filename
|
||||
return GGUFReader(sample_file).tensors
|
||||
|
||||
|
||||
def get_gguf_MoE_tensors(
|
||||
hidden_size: int, quant_type: GGMLQuantizationType
|
||||
) -> list[ReaderTensor]:
|
||||
sample_dir = GGUF_SAMPLE_MOE
|
||||
filename = f"Quant_{quant_type.name}_{hidden_size}.gguf"
|
||||
sample_file = Path(sample_dir) / filename
|
||||
return GGUFReader(sample_file).tensors
|
||||
|
||||
|
||||
DTYPES = [torch.bfloat16] # [torch.half, torch.bfloat16, torch.float32]
|
||||
# Hidden_size for testing, must match the sample file in HF repo,
|
||||
# we have `hidden_size = 256, 1024` for test in HF repo currently.
|
||||
HIDDEN_SIZES = [256, 1024]
|
||||
NUM_TOKENS = [7, 2050] # Arbitrary values for testing
|
||||
SEEDS = [0]
|
||||
QUANT_TYPES = [
|
||||
# i-matrix
|
||||
GGMLQuantizationType.IQ1_M,
|
||||
GGMLQuantizationType.IQ1_S,
|
||||
GGMLQuantizationType.IQ2_S,
|
||||
GGMLQuantizationType.IQ2_XS,
|
||||
GGMLQuantizationType.IQ3_S,
|
||||
GGMLQuantizationType.IQ3_XXS,
|
||||
GGMLQuantizationType.IQ4_NL,
|
||||
GGMLQuantizationType.IQ4_XS,
|
||||
# k-quants
|
||||
GGMLQuantizationType.Q2_K,
|
||||
GGMLQuantizationType.Q3_K,
|
||||
GGMLQuantizationType.Q4_K,
|
||||
GGMLQuantizationType.Q5_K,
|
||||
GGMLQuantizationType.Q6_K,
|
||||
# standard quantization
|
||||
GGMLQuantizationType.Q4_0,
|
||||
GGMLQuantizationType.Q5_0,
|
||||
GGMLQuantizationType.Q8_0,
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("quant_type", QUANT_TYPES)
|
||||
@torch.inference_mode()
|
||||
def test_dequantize(
|
||||
hidden_size: int, dtype: torch.dtype, quant_type: GGMLQuantizationType
|
||||
):
|
||||
tensors = get_gguf_sample_tensors(hidden_size, quant_type)
|
||||
for tensor in tensors:
|
||||
shape_str = tensor.name.split("_")[-1]
|
||||
shape = map(int, shape_str.split("x"))
|
||||
|
||||
ref_output = torch.tensor(
|
||||
dequantize(tensor.data, quant_type), device="cuda"
|
||||
).to(dtype)
|
||||
output = ops.ggml_dequantize(
|
||||
torch.tensor(tensor.data, device="cuda"), quant_type, *list(shape), dtype
|
||||
)
|
||||
|
||||
torch.testing.assert_close(output, ref_output, atol=1e-2, rtol=4e-2)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("quant_type", QUANT_TYPES)
|
||||
@torch.inference_mode()
|
||||
def test_mmvq(hidden_size: int, dtype: torch.dtype, quant_type: GGMLQuantizationType):
|
||||
current_platform.seed_everything(0)
|
||||
|
||||
tensors = get_gguf_sample_tensors(hidden_size, quant_type)
|
||||
x = torch.rand((1, hidden_size), dtype=dtype, device="cuda")
|
||||
for tensor in tensors:
|
||||
weight = torch.tensor(dequantize(tensor.data, quant_type), device="cuda").to(
|
||||
dtype
|
||||
)
|
||||
ref_output = x @ weight.T
|
||||
|
||||
qweight = torch.tensor(tensor.data, device="cuda")
|
||||
output = ops.ggml_mul_mat_vec_a8(qweight, x, quant_type, qweight.shape[0]).to(
|
||||
dtype
|
||||
)
|
||||
|
||||
torch.testing.assert_close(output, ref_output, atol=1, rtol=1e-1)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
|
||||
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize(
|
||||
"quant_type",
|
||||
[
|
||||
# k-quants
|
||||
GGMLQuantizationType.Q2_K,
|
||||
GGMLQuantizationType.Q3_K,
|
||||
GGMLQuantizationType.Q4_K,
|
||||
GGMLQuantizationType.Q5_K,
|
||||
GGMLQuantizationType.Q6_K,
|
||||
# standard quants
|
||||
GGMLQuantizationType.Q4_0,
|
||||
GGMLQuantizationType.Q5_0,
|
||||
GGMLQuantizationType.Q8_0,
|
||||
],
|
||||
)
|
||||
@torch.inference_mode()
|
||||
def test_mmq(
|
||||
num_tokens: int,
|
||||
hidden_size: int,
|
||||
dtype: torch.dtype,
|
||||
quant_type: GGMLQuantizationType,
|
||||
):
|
||||
current_platform.seed_everything(0)
|
||||
|
||||
tensors = get_gguf_sample_tensors(hidden_size, quant_type)
|
||||
x = torch.rand((num_tokens, hidden_size), dtype=dtype, device="cuda")
|
||||
for tensor in tensors:
|
||||
weight = torch.tensor(dequantize(tensor.data, quant_type), device="cuda").to(
|
||||
dtype
|
||||
)
|
||||
ref_output = x @ weight.T
|
||||
|
||||
qweight = torch.tensor(tensor.data, device="cuda")
|
||||
output = ops.ggml_mul_mat_a8(qweight, x, quant_type, qweight.shape[0])
|
||||
atols = {torch.half: 1, torch.bfloat16: 1.5, torch.float: 1.2}
|
||||
# test matrix has inputs centered around 0 and lower precision from
|
||||
# bfloat16 tends to accumulate and can greatly inflate rtol
|
||||
# since outputs are also very close to 0
|
||||
rtols = {torch.half: 1e-1, torch.bfloat16: 1e4, torch.float: 2e1}
|
||||
torch.testing.assert_close(
|
||||
output, ref_output, atol=atols[dtype], rtol=rtols[dtype]
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
|
||||
@pytest.mark.parametrize("hidden_size", [512])
|
||||
@pytest.mark.parametrize("top_k", [4, 8])
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("quant_type", QUANT_TYPES)
|
||||
@torch.inference_mode()
|
||||
def test_moe(
|
||||
num_tokens: int,
|
||||
hidden_size: int,
|
||||
dtype: torch.dtype,
|
||||
quant_type: GGMLQuantizationType,
|
||||
top_k: int,
|
||||
):
|
||||
current_platform.seed_everything(0)
|
||||
H, E = 1024, 256
|
||||
|
||||
x = torch.rand((num_tokens, H), dtype=dtype, device="cuda")
|
||||
|
||||
topk_weights = torch.rand(num_tokens, top_k, device="cuda", dtype=dtype)
|
||||
topk_ids = torch.randint(
|
||||
0, E, (num_tokens, top_k), device="cuda", dtype=torch.int32
|
||||
)
|
||||
|
||||
tensors = get_gguf_MoE_tensors(hidden_size, quant_type)
|
||||
|
||||
w13 = tensors[0]
|
||||
w2 = tensors[1]
|
||||
|
||||
w13_dequant = torch.tensor(dequantize(w13.data, quant_type), device="cuda").to(
|
||||
dtype
|
||||
)
|
||||
|
||||
w2_dequant = torch.tensor(dequantize(w2.data, quant_type), device="cuda").to(dtype)
|
||||
|
||||
output = _fused_moe_gguf(
|
||||
x,
|
||||
torch.tensor(w13.data, device="cuda"),
|
||||
torch.tensor(w2.data, device="cuda"),
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
quant_type,
|
||||
quant_type,
|
||||
"silu",
|
||||
)
|
||||
|
||||
ref_output = fused_experts(
|
||||
x, w13_dequant, w2_dequant, topk_weights, topk_ids
|
||||
).reshape(output.shape)
|
||||
torch.testing.assert_close(output, ref_output, atol=1, rtol=1e-1)
|
||||
35
tests/kernels/quantization/test_gptq.py
Normal file
35
tests/kernels/quantization/test_gptq.py
Normal file
@@ -0,0 +1,35 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import torch
|
||||
|
||||
from tests.kernels.utils import opcheck
|
||||
from vllm import _custom_ops as ops # noqa: F401
|
||||
|
||||
|
||||
def test_gptq_shuffle_opcheck():
|
||||
weight = torch.randint(
|
||||
-2000000, 2000000, (1792, 4096), device="cuda", dtype=torch.int32
|
||||
)
|
||||
perm = torch.empty((0,), device="cuda", dtype=torch.int32)
|
||||
bit = 4
|
||||
opcheck(torch.ops._C.gptq_shuffle, (weight, perm, bit))
|
||||
|
||||
|
||||
def test_gptq_gemm_opcheck():
|
||||
a = torch.rand((240, 4096), device="cuda", dtype=torch.float16)
|
||||
weight = torch.randint(
|
||||
-2000000, 2000000, (512, 6144), device="cuda", dtype=torch.int32
|
||||
)
|
||||
zeros = torch.zeros((32, 768), device="cuda", dtype=torch.int32)
|
||||
scales = torch.rand((32, 6144), device="cuda", dtype=torch.float16)
|
||||
idx = torch.empty((0,), device="cuda", dtype=torch.int32)
|
||||
use_exllama = True
|
||||
bit = 4
|
||||
# Test both GPTQv1 and GPTQv2 format
|
||||
opcheck(
|
||||
torch.ops._C.gptq_gemm, (a, weight, zeros, scales, idx, use_exllama, True, bit)
|
||||
)
|
||||
opcheck(
|
||||
torch.ops._C.gptq_gemm, (a, weight, zeros, scales, idx, use_exllama, False, bit)
|
||||
)
|
||||
33
tests/kernels/quantization/test_hadacore.py
Normal file
33
tests/kernels/quantization/test_hadacore.py
Normal file
@@ -0,0 +1,33 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import math
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from compressed_tensors.transform import deterministic_hadamard_matrix
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
if current_platform.is_rocm():
|
||||
pytest.skip(
|
||||
"These tests require hadacore_transform, not supported on ROCm.",
|
||||
allow_module_level=True,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("batch_size", [1, 32])
|
||||
@pytest.mark.parametrize("hidden_dim", [2**n for n in range(10)])
|
||||
def test_hadacore(batch_size, hidden_dim, dtype=torch.bfloat16, device="cuda"):
|
||||
x = torch.eye(hidden_dim, dtype=dtype, device=device)
|
||||
hadamard = deterministic_hadamard_matrix(
|
||||
hidden_dim, dtype=torch.float64, device="cuda"
|
||||
) / math.sqrt(hidden_dim)
|
||||
|
||||
y = ops.hadacore_transform(x.clone())
|
||||
y_true = (x.to(hadamard.dtype) @ hadamard.T).to(y.dtype)
|
||||
assert torch.allclose(y, y_true)
|
||||
|
||||
y = ops.hadacore_transform(y)
|
||||
assert torch.allclose(y, x)
|
||||
155
tests/kernels/quantization/test_int8_kernel.py
Normal file
155
tests/kernels/quantization/test_int8_kernel.py
Normal file
@@ -0,0 +1,155 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
# Adapted from https://github.com/sgl-project/sglang/blob/main/test/srt/test_int8_kernel.py
|
||||
import itertools
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.fused_moe import fused_experts
|
||||
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
|
||||
from vllm.model_executor.layers.quantization.utils.int8_utils import (
|
||||
per_token_quant_int8,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
if current_platform.get_device_capability() < (7, 0):
|
||||
pytest.skip("INT8 Triton requires CUDA 7.0 or higher", allow_module_level=True)
|
||||
|
||||
|
||||
def native_w8a8_per_token_matmul(A, B, As, Bs, output_dtype=torch.float16):
|
||||
"""Matrix multiplication function that supports per-token input
|
||||
quantization and per-column weight quantization"""
|
||||
A = A.to(torch.float32)
|
||||
B = B.to(torch.float32)
|
||||
|
||||
assert A.shape[-1] == B.shape[-1], "Dimension mismatch"
|
||||
assert B.ndim == 2 and B.is_contiguous(), "B must be a 2D contiguous tensor"
|
||||
|
||||
# Reshape input
|
||||
M = A.numel() // A.shape[-1]
|
||||
B = B.t() # Transpose weight matrix
|
||||
N, K = B.shape
|
||||
origin_C_shape = A.shape[:-1] + (K,)
|
||||
A = A.reshape(M, N)
|
||||
|
||||
# As is per-token [M, 1], Bs is per-column [1, K]
|
||||
C = torch.matmul(A, B) # [M, K]
|
||||
C = As * C * Bs.view(1, -1) # Broadcast per-column scale
|
||||
|
||||
return C.reshape(origin_C_shape).to(output_dtype)
|
||||
|
||||
|
||||
def torch_w8a8_per_column_moe(a, w1, w2, w1_s, w2_s, topk, topk_weight, topk_ids):
|
||||
"""This function performs fused moe with per-column int8 quantization
|
||||
using native torch."""
|
||||
|
||||
B, D = a.shape
|
||||
# Perform per-token quantization
|
||||
a_q, a_s = per_token_quant_int8(a)
|
||||
# Repeat tokens to match topk
|
||||
a_q = a_q.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
|
||||
# Also repeat the scale
|
||||
a_s = a_s.view(B, -1, 1).repeat(1, topk, 1).reshape(-1, 1) # [B*topk, 1]
|
||||
|
||||
out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device)
|
||||
|
||||
# Calculate routing
|
||||
topk_weight = topk_weight.view(-1)
|
||||
topk_ids = topk_ids.view(-1)
|
||||
# Process each expert
|
||||
for i in range(w1.shape[0]):
|
||||
mask = topk_ids == i
|
||||
if mask.sum():
|
||||
# First MLP layer: note that a_s is now per-token
|
||||
inter_out = native_w8a8_per_token_matmul(
|
||||
a_q[mask], w1[i], a_s[mask], w1_s[i], output_dtype=a.dtype
|
||||
)
|
||||
# Activation function
|
||||
act_out = SiluAndMul().forward_native(inter_out)
|
||||
# Quantize activation output with per-token
|
||||
act_out_q, act_out_s = per_token_quant_int8(act_out)
|
||||
|
||||
# Second MLP layer
|
||||
out[mask] = native_w8a8_per_token_matmul(
|
||||
act_out_q, w2[i], act_out_s, w2_s[i], output_dtype=a.dtype
|
||||
)
|
||||
# Apply routing weights and sum
|
||||
return (
|
||||
out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1).to(out.dtype)
|
||||
).sum(dim=1)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True, scope="module")
|
||||
def setup_cuda():
|
||||
"""Sets the default CUDA device for all tests in this module."""
|
||||
torch.set_default_device("cuda")
|
||||
|
||||
|
||||
DTYPES = [torch.half, torch.bfloat16]
|
||||
M = [1, 33]
|
||||
N = [128, 1024]
|
||||
K = [256, 4096]
|
||||
E = [8]
|
||||
TOP_KS = [2, 6]
|
||||
SEEDS = [0]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"M, N, K, E, topk, dtype, seed",
|
||||
itertools.product(M, N, K, E, TOP_KS, DTYPES, SEEDS),
|
||||
)
|
||||
@torch.inference_mode()
|
||||
def test_w8a8_fp8_fused_moe(M, N, K, E, topk, dtype, seed):
|
||||
torch.manual_seed(seed)
|
||||
# Initialize int8 quantization parameters
|
||||
factor_for_scale = 1e-2
|
||||
int8_max = 127
|
||||
int8_min = -128
|
||||
|
||||
# Input tensor
|
||||
# M * K
|
||||
a = torch.randn((M, K), dtype=dtype) / 10
|
||||
|
||||
# Generate int8 weights
|
||||
w1_fp32 = (torch.rand((E, 2 * N, K), dtype=torch.float32) - 0.5) * 2
|
||||
w1 = (w1_fp32 * int8_max).clamp(min=int8_min, max=int8_max).to(torch.int8)
|
||||
|
||||
w2_fp32 = (torch.rand((E, K, N), dtype=torch.float32) - 0.5) * 2
|
||||
w2 = (w2_fp32 * int8_max).clamp(min=int8_min, max=int8_max).to(torch.int8)
|
||||
|
||||
# Generate scale for each column (per-column quantization)
|
||||
w1_s = torch.rand(E, 2 * N, device=w1_fp32.device) * factor_for_scale
|
||||
w2_s = torch.rand(E, K, device=w2_fp32.device) * factor_for_scale
|
||||
score = torch.randn((M, E), dtype=dtype)
|
||||
score = torch.softmax(score, dim=-1, dtype=torch.float32)
|
||||
topk_weights, topk_ids = torch.topk(score, topk)
|
||||
|
||||
ref_out = torch_w8a8_per_column_moe(
|
||||
a, w1, w2, w1_s, w2_s, topk, topk_weights, topk_ids
|
||||
)
|
||||
|
||||
quant_config = FusedMoEQuantConfig.make(
|
||||
torch.int8,
|
||||
per_act_token_quant=True,
|
||||
block_shape=None,
|
||||
w1_scale=w1_s,
|
||||
w2_scale=w2_s,
|
||||
)
|
||||
|
||||
out = fused_experts(
|
||||
a,
|
||||
w1,
|
||||
w2,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
|
||||
# Check results
|
||||
rel_diff = torch.mean(
|
||||
torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))
|
||||
) / torch.mean(torch.abs(ref_out.to(torch.float32)))
|
||||
assert rel_diff < 0.05
|
||||
195
tests/kernels/quantization/test_int8_quant.py
Normal file
195
tests/kernels/quantization/test_int8_quant.py
Normal file
@@ -0,0 +1,195 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from tests.kernels.quant_utils import ref_dynamic_per_token_quant
|
||||
from tests.kernels.utils import opcheck
|
||||
from vllm._custom_ops import scaled_int8_quant
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
DTYPES = [torch.bfloat16, torch.float]
|
||||
HIDDEN_SIZES = [17, 1024, 1025, 1026, 5137, 8193]
|
||||
NUM_TOKENS = [1, 7, 4096]
|
||||
SEEDS = [0]
|
||||
SCALE = [0.1, 2.1]
|
||||
|
||||
|
||||
def opcheck_int8_quant_static(output, input, scale, azp=None):
|
||||
if azp is None:
|
||||
opcheck(torch.ops._C.static_scaled_int8_quant, (output, input, scale, None))
|
||||
else:
|
||||
opcheck(torch.ops._C.static_scaled_int8_quant, (output, input, scale, azp))
|
||||
|
||||
|
||||
def opcheck_int8_quant_dynamic(output, input, symmetric=True):
|
||||
scale = torch.empty(
|
||||
(input.numel() // input.shape[-1], 1), device=input.device, dtype=torch.float32
|
||||
)
|
||||
if symmetric:
|
||||
opcheck(torch.ops._C.dynamic_scaled_int8_quant, (output, input, scale, None))
|
||||
else:
|
||||
azp = torch.empty(
|
||||
(input.numel() // input.shape[-1], 1),
|
||||
device=input.device,
|
||||
dtype=torch.int32,
|
||||
)
|
||||
opcheck(torch.ops._C.dynamic_scaled_int8_quant, (output, input, scale, azp))
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
|
||||
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("seed", SEEDS)
|
||||
@torch.inference_mode()
|
||||
def test_dynamic_scaled_int8_quant(
|
||||
num_tokens: int, hidden_size: int, dtype: torch.dtype, seed: int
|
||||
) -> None:
|
||||
current_platform.seed_everything(seed)
|
||||
|
||||
x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda") * 1000
|
||||
|
||||
# reference
|
||||
ref_out, ref_scales = ref_dynamic_per_token_quant(x, torch.int8)
|
||||
# kernel
|
||||
ops_out, ops_scales, _ = scaled_int8_quant(x)
|
||||
|
||||
torch.testing.assert_close(ops_scales, ref_scales)
|
||||
# big atol to account for rounding errors
|
||||
torch.testing.assert_close(ops_out, ref_out, atol=1, rtol=0.0)
|
||||
|
||||
opcheck_int8_quant_dynamic(ops_out, x)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
|
||||
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("seed", SEEDS)
|
||||
@torch.inference_mode()
|
||||
def test_dynamic_scaled_int8_azp_quant(
|
||||
num_tokens: int, hidden_size: int, dtype: torch.dtype, seed: int
|
||||
) -> None:
|
||||
current_platform.seed_everything(seed)
|
||||
int8_traits = torch.iinfo(torch.int8)
|
||||
|
||||
x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda") * 1000 - 300
|
||||
|
||||
x_token_max, _ = x.to(dtype=torch.float32).max(dim=1, keepdim=True)
|
||||
x_token_min, _ = x.to(dtype=torch.float32).min(dim=1, keepdim=True)
|
||||
|
||||
# calculate scale and azp, and adjust the range
|
||||
scales = (x_token_max - x_token_min) / torch.tensor(255.0)
|
||||
azps = torch.round(torch.tensor(-128.0) - x_token_min / scales).to(torch.int32)
|
||||
|
||||
torch_out = (
|
||||
((x / scales).round() + azps)
|
||||
.clamp(int8_traits.min, int8_traits.max)
|
||||
.to(torch.int8)
|
||||
)
|
||||
assert torch_out.min() >= int8_traits.min and torch_out.max() <= int8_traits.max
|
||||
|
||||
ops_out, scales_out, azp_out = scaled_int8_quant(x, symmetric=False)
|
||||
|
||||
if not torch.allclose(scales_out, scales):
|
||||
print(torch.argmax(torch.abs(scales_out - scales)))
|
||||
torch.testing.assert_close(scales_out, scales)
|
||||
# big atol to account for rounding errors
|
||||
torch.testing.assert_close(azp_out, azps, atol=1, rtol=0.0)
|
||||
# if AZP is off by 1, after rounding-to-even, the output may be off by 2
|
||||
torch.testing.assert_close(ops_out, torch_out, atol=2, rtol=0.0)
|
||||
|
||||
opcheck_int8_quant_dynamic(ops_out, x, False)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
|
||||
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("seed", SEEDS)
|
||||
@pytest.mark.parametrize("scale", SCALE)
|
||||
@torch.inference_mode()
|
||||
def test_static_scaled_int8_quant(
|
||||
num_tokens: int, hidden_size: int, dtype: torch.dtype, seed: int, scale: float
|
||||
) -> None:
|
||||
current_platform.seed_everything(seed)
|
||||
int8_traits = torch.iinfo(torch.int8)
|
||||
|
||||
x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda") * 1000
|
||||
scale_arg = torch.tensor([scale], dtype=torch.float32, device="cuda")
|
||||
|
||||
out1 = (
|
||||
(x / scale_arg).round().clamp(int8_traits.min, int8_traits.max).to(torch.int8)
|
||||
)
|
||||
out2, scale2, _ = scaled_int8_quant(x, scale_arg)
|
||||
assert scale2 is scale_arg
|
||||
|
||||
# big atol to account for rounding errors
|
||||
torch.testing.assert_close(out1, out2, atol=1, rtol=0.0)
|
||||
|
||||
opcheck_int8_quant_static(out2, x, scale_arg)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
|
||||
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("seed", SEEDS)
|
||||
@pytest.mark.parametrize("scale", SCALE)
|
||||
@pytest.mark.parametrize("azp", [-255, 54])
|
||||
@torch.inference_mode()
|
||||
def test_static_scaled_int8_azp_quant(
|
||||
num_tokens: int,
|
||||
hidden_size: int,
|
||||
dtype: torch.dtype,
|
||||
seed: int,
|
||||
scale: float,
|
||||
azp: int,
|
||||
) -> None:
|
||||
current_platform.seed_everything(seed)
|
||||
int8_traits = torch.iinfo(torch.int8)
|
||||
|
||||
x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda") * 1000 - 300
|
||||
|
||||
out1 = (
|
||||
((x / scale).round() + azp)
|
||||
.clamp(int8_traits.min, int8_traits.max)
|
||||
.to(torch.int8)
|
||||
)
|
||||
scale_arg = torch.tensor([scale], dtype=torch.float32, device="cuda")
|
||||
azp_arg = torch.tensor([azp], dtype=torch.int32, device="cuda")
|
||||
|
||||
out2, scale2, azp2 = scaled_int8_quant(x, scale_arg, azp_arg, symmetric=False)
|
||||
assert scale2 is scale_arg
|
||||
assert azp2 is azp_arg
|
||||
|
||||
# big atol to account for rounding errors
|
||||
torch.testing.assert_close(out1, out2, atol=1, rtol=0.0)
|
||||
|
||||
opcheck_int8_quant_static(out2, x, scale_arg, azp_arg)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("is_max", [True, False])
|
||||
@torch.inference_mode()
|
||||
def test_static_scaled_int8_azp_quant_saturating_cast(is_max: bool) -> None:
|
||||
# Test that the saturating cast works correctly for values near i32 max/min
|
||||
|
||||
from numpy import inf, nextafter
|
||||
|
||||
int32_traits = torch.iinfo(torch.int32)
|
||||
val = float(int32_traits.max if is_max else int32_traits.min)
|
||||
|
||||
x_vals = [[nextafter(val, inf), val + 1, val, val - 1, nextafter(val, -inf)]]
|
||||
x = torch.tensor(x_vals, dtype=torch.float32, device="cuda")
|
||||
|
||||
# The calculation in the kernel is: cast<int8>(cast<int32>(x / scale) + azp)
|
||||
# where cast<T> is a saturating cast to type T.
|
||||
# Scale is set to 1.0 so that the input values are the ones that are cast.
|
||||
# AZP is set to 0 to make sure the int8 saturating cast is tested as well.
|
||||
scale = torch.scalar_tensor(1.0, dtype=torch.float32, device="cuda")
|
||||
azp = torch.scalar_tensor(0, dtype=torch.int32, device="cuda")
|
||||
|
||||
int8_traits = torch.iinfo(torch.int8)
|
||||
val_i8 = int8_traits.max if is_max else int8_traits.min
|
||||
expected = torch.full((1, 5), val_i8, dtype=torch.int8, device="cuda")
|
||||
|
||||
out, _, _ = scaled_int8_quant(x, scale, azp, symmetric=False)
|
||||
torch.testing.assert_close(expected, out, atol=0, rtol=0)
|
||||
447
tests/kernels/quantization/test_machete_mm.py
Normal file
447
tests/kernels/quantization/test_machete_mm.py
Normal file
@@ -0,0 +1,447 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Tests for the machete kernel.
|
||||
|
||||
Run `pytest tests/kernels/quantization/test_machete_mm.py`.
|
||||
"""
|
||||
|
||||
import math
|
||||
from dataclasses import dataclass, fields
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from tests.kernels.utils import opcheck
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.model_executor.layers.quantization.utils.machete_utils import (
|
||||
query_machete_supported_group_sizes,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
pack_rows,
|
||||
quantize_weights,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.scalar_type import ScalarType, scalar_types
|
||||
|
||||
if current_platform.is_rocm():
|
||||
pytest.skip(
|
||||
"These tests require machete_prepack_B, not supported on ROCm.",
|
||||
allow_module_level=True,
|
||||
)
|
||||
|
||||
CUDA_DEVICES = [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)]
|
||||
|
||||
# TODO: in future PR refactor this and `is_quant_method_supported` in the kernel
|
||||
# unit tests to a common utility function. Currently the use of
|
||||
# `is_quant_method_supported` conflates kernels with quantization methods
|
||||
# an assumption which is breaking down as quantizations methods can have
|
||||
# have kernels and some kernels support multiple quantization methods.
|
||||
IS_SUPPORTED_BY_GPU = current_platform.get_device_capability()[0] >= 9
|
||||
|
||||
MNK_SHAPES = [
|
||||
(1, 128, 128),
|
||||
(1, 8192, 28672),
|
||||
(13, 8192, 4096),
|
||||
(26, 4096, 8192),
|
||||
(64, 4096, 4096),
|
||||
(64, 8192, 28672),
|
||||
(257, 128, 4096),
|
||||
(257, 4224, 4160),
|
||||
(1024, 8192, 4096),
|
||||
]
|
||||
|
||||
|
||||
@dataclass
|
||||
class TypeConfig:
|
||||
act_type: torch.dtype
|
||||
weight_type: ScalarType
|
||||
output_type: torch.dtype | None
|
||||
group_scale_type: torch.dtype | None
|
||||
group_zero_type: torch.dtype | None
|
||||
channel_scale_type: torch.dtype | None
|
||||
token_scale_type: torch.dtype | None
|
||||
|
||||
|
||||
@dataclass
|
||||
class Tensors:
|
||||
w_ref: torch.Tensor
|
||||
a_ref: torch.Tensor
|
||||
a: torch.Tensor
|
||||
w_q: torch.Tensor
|
||||
w_g_s: torch.Tensor | None
|
||||
w_g_zp: torch.Tensor | None
|
||||
w_ch_s: torch.Tensor | None
|
||||
w_tok_s: torch.Tensor | None
|
||||
|
||||
|
||||
# (Act Type, Weight Type, Output Type, Scale Type, ZeroPoints,
|
||||
# Ch Scales Type, Tok Scales Type)
|
||||
# NOTE: None "Scale Type" means the act type is floating point
|
||||
# None "Output Type" means the output type is the same as the act type
|
||||
TestTypeTuple = tuple[
|
||||
list[torch.dtype], ScalarType, torch.dtype | None, torch.dtype | None, bool
|
||||
]
|
||||
TEST_TYPES = [
|
||||
# GPTQ style
|
||||
*(
|
||||
TypeConfig(
|
||||
act_type=a_type,
|
||||
weight_type=w_type,
|
||||
output_type=None,
|
||||
group_scale_type=a_type,
|
||||
group_zero_type=None,
|
||||
channel_scale_type=None,
|
||||
token_scale_type=None,
|
||||
)
|
||||
for w_type in [scalar_types.uint4b8, scalar_types.uint8b128]
|
||||
for a_type in [torch.float16, torch.bfloat16]
|
||||
),
|
||||
# AWQ style
|
||||
*(
|
||||
TypeConfig(
|
||||
act_type=a_type,
|
||||
weight_type=w_type,
|
||||
output_type=None,
|
||||
group_scale_type=a_type,
|
||||
group_zero_type=a_type,
|
||||
channel_scale_type=None,
|
||||
token_scale_type=None,
|
||||
)
|
||||
for w_type in [scalar_types.uint4, scalar_types.uint8]
|
||||
for a_type in [torch.float16, torch.bfloat16]
|
||||
),
|
||||
# # QQQ style
|
||||
# *(TypeConfig(act_type=torch.int8,
|
||||
# weight_type=scalar_types.uint4b8,
|
||||
# output_type=torch.float16,
|
||||
# group_scale_type=group_scale_type,
|
||||
# group_zero_type=None,
|
||||
# channel_scale_type=torch.float,
|
||||
# token_scale_type=torch.float)
|
||||
# for group_scale_type in [None, torch.float16]),
|
||||
# *(TypeConfig(act_type=torch.float8_e4m3fn,
|
||||
# weight_type=scalar_types.uint4b8,
|
||||
# output_type=torch.float16,
|
||||
# group_scale_type=group_scale_type,
|
||||
# group_zero_type=None,
|
||||
# channel_scale_type=torch.float,
|
||||
# token_scale_type=torch.float)
|
||||
# for group_scale_type in [None, torch.float16]),
|
||||
]
|
||||
|
||||
# TODO: in future PR refactor this and `is_quant_method_supported` in the kernel
|
||||
# unit tests to a common utility function. Currently the use of
|
||||
# `is_quant_method_supported` conflates kernels with quantization methods
|
||||
# an assumption which is breaking down as quantizations methods can have
|
||||
# have kernels and some kernels support multiple quantization methods.
|
||||
IS_SUPPORTED_BY_GPU = current_platform.has_device_capability(90)
|
||||
|
||||
|
||||
def rand_data(shape, dtype=torch.float16, scale=1, offset=0):
|
||||
if dtype.is_floating_point:
|
||||
return (scale * torch.rand(shape, device="cuda") - offset).to(dtype)
|
||||
else:
|
||||
return torch.randint(-8, 7, shape, dtype=dtype, device="cuda")
|
||||
|
||||
|
||||
def maybe_convert_zeropoints(zps: torch.Tensor | None, s: torch.Tensor):
|
||||
return zps if zps is None else -1 * s * (zps.to(s.dtype))
|
||||
|
||||
|
||||
def group_size_valid(shape: tuple[int, int, int], group_size: int | None) -> bool:
|
||||
return group_size is None or group_size == -1 or shape[2] % group_size == 0
|
||||
|
||||
|
||||
def machete_quantize_and_pack(
|
||||
atype: torch.dtype,
|
||||
w: torch.Tensor,
|
||||
wtype: ScalarType,
|
||||
stype: torch.dtype | None,
|
||||
group_size: int | None,
|
||||
zero_points: bool = False,
|
||||
):
|
||||
assert wtype.is_integer(), "TODO: support floating point weights"
|
||||
|
||||
w_ref, w_q, w_s, w_zp = quantize_weights(
|
||||
w,
|
||||
wtype,
|
||||
group_size=group_size,
|
||||
zero_points=zero_points,
|
||||
# to match how the kernel applies zps
|
||||
ref_zero_points_after_scales=True,
|
||||
)
|
||||
|
||||
w_q = pack_rows(w_q, wtype.size_bits, *w_q.shape)
|
||||
w_q = w_q.t().contiguous().t() # convert to col major
|
||||
|
||||
w_q_machete = ops.machete_prepack_B(w_q, atype, wtype, stype)
|
||||
opcheck(torch.ops._C.machete_prepack_B, (w_q, atype, wtype.id, stype))
|
||||
|
||||
return w_ref, w_q_machete, w_s, w_zp
|
||||
|
||||
|
||||
def create_test_tensors(
|
||||
shape: tuple[int, int, int],
|
||||
types: TypeConfig,
|
||||
group_size: int | None,
|
||||
subset_stride_factor: int | None = None,
|
||||
) -> Tensors:
|
||||
m, n, k = shape
|
||||
factor = subset_stride_factor or 1
|
||||
|
||||
print(
|
||||
"create_test_tensors, shape:", shape, "types:", types, "group_size:", group_size
|
||||
)
|
||||
|
||||
a = rand_data((m * factor, k * factor), types.act_type, scale=3, offset=2)
|
||||
w = rand_data((k * factor, n * factor), types.act_type, scale=3, offset=1)
|
||||
|
||||
if factor > 1:
|
||||
a = a[0:m, 0:k]
|
||||
w = w[0:k, 0:n]
|
||||
|
||||
if types.group_scale_type is not None:
|
||||
w = w.to(types.group_scale_type)
|
||||
if w.dtype.itemsize == 1:
|
||||
w = w.to(torch.float16)
|
||||
|
||||
w_ref, w_q_packed, w_s, w_zp = machete_quantize_and_pack(
|
||||
a.dtype,
|
||||
w,
|
||||
types.weight_type,
|
||||
types.group_scale_type,
|
||||
group_size,
|
||||
types.group_zero_type is not None,
|
||||
)
|
||||
|
||||
if not a.dtype.is_floating_point:
|
||||
aiinfo = torch.iinfo(a.dtype)
|
||||
w_ref = w_ref.round().clamp(aiinfo.min, aiinfo.max)
|
||||
|
||||
a_ref = a.to(torch.float32)
|
||||
w_ref = w_ref.to(torch.float32)
|
||||
|
||||
w_ch_s = (
|
||||
None
|
||||
if types.channel_scale_type is None
|
||||
else rand_data((n,), types.channel_scale_type)
|
||||
)
|
||||
w_tok_s = (
|
||||
None
|
||||
if types.token_scale_type is None
|
||||
else rand_data((m,), types.token_scale_type)
|
||||
)
|
||||
|
||||
return Tensors(
|
||||
w_ref=w_ref,
|
||||
a_ref=a_ref,
|
||||
a=a,
|
||||
w_q=w_q_packed,
|
||||
w_g_s=w_s,
|
||||
w_g_zp=maybe_convert_zeropoints(w_zp, w_s),
|
||||
w_ch_s=w_ch_s,
|
||||
w_tok_s=w_tok_s,
|
||||
)
|
||||
|
||||
|
||||
# None stype means scales use the same dtype as a
|
||||
def machete_mm_test_helper(
|
||||
types: TypeConfig,
|
||||
tensors: Tensors,
|
||||
group_size: int | None = None,
|
||||
schedule: str | None = None,
|
||||
):
|
||||
output_ref = torch.matmul(tensors.a_ref, tensors.w_ref)
|
||||
output_ref_type = output_ref.dtype
|
||||
|
||||
if tensors.w_ch_s is not None:
|
||||
output_ref = (
|
||||
output_ref.to(tensors.w_ch_s.dtype) * tensors.w_ch_s.unsqueeze(0)
|
||||
).to(output_ref_type)
|
||||
if tensors.w_tok_s is not None:
|
||||
output_ref = (
|
||||
output_ref.to(tensors.w_tok_s.dtype) * tensors.w_tok_s.unsqueeze(1)
|
||||
).to(output_ref_type)
|
||||
|
||||
output = ops.machete_mm(
|
||||
a=tensors.a,
|
||||
b_q=tensors.w_q,
|
||||
b_type=types.weight_type,
|
||||
b_group_scales=tensors.w_g_s,
|
||||
b_group_zeros=tensors.w_g_zp,
|
||||
b_group_size=group_size,
|
||||
b_channel_scales=tensors.w_ch_s,
|
||||
a_token_scales=tensors.w_tok_s,
|
||||
out_type=types.output_type,
|
||||
schedule=schedule,
|
||||
)
|
||||
|
||||
print(output)
|
||||
print(output_ref)
|
||||
|
||||
# Relax atol as our reduction dim becomes larger (more rounding error)
|
||||
# Relax atol when we have zeropoints since the way machete applies
|
||||
# zeropoints (after scales) causes noise around 0
|
||||
atol = (
|
||||
1
|
||||
if tensors.w_g_zp is not None
|
||||
else min(5e-2 * math.sqrt(tensors.a.shape[1]), 1)
|
||||
)
|
||||
rtol = 1e-1 if tensors.a.element_size() >= 2 else 2e-1
|
||||
torch.testing.assert_close(
|
||||
output, output_ref.to(output.dtype), rtol=rtol, atol=atol
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not IS_SUPPORTED_BY_GPU, reason="Machete is not supported on this GPU type."
|
||||
)
|
||||
@pytest.mark.parametrize("shape", MNK_SHAPES, ids=lambda x: "x".join(str(v) for v in x))
|
||||
@pytest.mark.parametrize("types", TEST_TYPES)
|
||||
def test_machete_all_schedules(shape, types: TypeConfig):
|
||||
group_sizes: list[int | None] = []
|
||||
if types.group_scale_type is None:
|
||||
group_sizes = [None]
|
||||
else:
|
||||
group_sizes = query_machete_supported_group_sizes(types.act_type)
|
||||
|
||||
for group_size in group_sizes:
|
||||
if not group_size_valid(shape, group_size):
|
||||
continue
|
||||
|
||||
tensors = create_test_tensors(shape, types, group_size)
|
||||
print(f"MNK = {shape}")
|
||||
for schedule in ops.machete_supported_schedules(
|
||||
types.act_type,
|
||||
types.weight_type,
|
||||
group_scales_type=types.group_scale_type,
|
||||
group_zeros_type=types.group_scale_type,
|
||||
out_type=types.output_type,
|
||||
):
|
||||
print(f"Testing schedule {schedule}")
|
||||
machete_mm_test_helper(types, tensors, group_size, schedule)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not IS_SUPPORTED_BY_GPU, reason="Machete is not supported on this GPU type."
|
||||
)
|
||||
@pytest.mark.parametrize("shape", MNK_SHAPES, ids=lambda x: "x".join(str(v) for v in x))
|
||||
@pytest.mark.parametrize("types", TEST_TYPES)
|
||||
def test_machete_heuristic(shape, types: TypeConfig):
|
||||
group_sizes: list[int | None] = []
|
||||
if types.group_scale_type is None:
|
||||
group_sizes = [None]
|
||||
else:
|
||||
group_sizes = query_machete_supported_group_sizes(types.act_type)
|
||||
|
||||
for group_size in group_sizes:
|
||||
if not group_size_valid(shape, group_size):
|
||||
continue
|
||||
|
||||
tensors = create_test_tensors(shape, types, group_size)
|
||||
machete_mm_test_helper(types, tensors, group_size)
|
||||
|
||||
|
||||
# Test working on other devices
|
||||
@pytest.mark.skipif(
|
||||
not IS_SUPPORTED_BY_GPU, reason="Machete is not supported on this GPU type."
|
||||
)
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
def test_machete_devices(device: str):
|
||||
group_size = 128
|
||||
|
||||
type_config = TypeConfig(
|
||||
act_type=torch.float16,
|
||||
weight_type=scalar_types.uint4b8,
|
||||
output_type=None,
|
||||
group_scale_type=torch.float16,
|
||||
group_zero_type=None,
|
||||
channel_scale_type=None,
|
||||
token_scale_type=None,
|
||||
)
|
||||
|
||||
tensors = create_test_tensors((512, 4096, 4096), type_config, group_size)
|
||||
|
||||
for field in fields(Tensors):
|
||||
tensor = getattr(tensors, field.name)
|
||||
if isinstance(tensor, torch.Tensor):
|
||||
setattr(tensors, field.name, tensor.to(device))
|
||||
|
||||
machete_mm_test_helper(type_config, tensors, group_size)
|
||||
|
||||
|
||||
# Test working with a subset of A and B
|
||||
@pytest.mark.skipif(
|
||||
not IS_SUPPORTED_BY_GPU, reason="Machete is not supported on this GPU type."
|
||||
)
|
||||
def test_machete_subset():
|
||||
group_size = 128
|
||||
|
||||
type_config = TypeConfig(
|
||||
act_type=torch.float16,
|
||||
weight_type=scalar_types.uint4b8,
|
||||
output_type=None,
|
||||
group_scale_type=torch.float16,
|
||||
group_zero_type=None,
|
||||
channel_scale_type=None,
|
||||
token_scale_type=None,
|
||||
)
|
||||
|
||||
tensors = create_test_tensors(
|
||||
(512, 4096, 4096), type_config, group_size, subset_stride_factor=2
|
||||
)
|
||||
machete_mm_test_helper(type_config, tensors, group_size)
|
||||
|
||||
|
||||
# Test to make sure cuda graphs work
|
||||
class MacheteLayer(torch.nn.Module):
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__()
|
||||
self.kwargs = kwargs
|
||||
|
||||
def forward(self, a):
|
||||
return ops.machete_mm(a=a, **self.kwargs)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not IS_SUPPORTED_BY_GPU, reason="Machete is not supported on this GPU type."
|
||||
)
|
||||
def test_machete_cuda_graph():
|
||||
m, n, k = 512, 4096, 4096
|
||||
|
||||
a = rand_data((m, k), torch.float16)
|
||||
b = rand_data((k, n), torch.float16)
|
||||
wtype = scalar_types.uint4b8
|
||||
stype = torch.float16
|
||||
group_size = 128
|
||||
zero_points = False
|
||||
|
||||
w_ref, w_q_packed, w_s, w_zp = machete_quantize_and_pack(
|
||||
a.dtype, b, wtype, stype, group_size, zero_points
|
||||
)
|
||||
|
||||
# Construct a trivial model with a single layer that calls a machete kernel
|
||||
model = MacheteLayer(
|
||||
b_q=w_q_packed,
|
||||
b_type=wtype,
|
||||
b_group_scales=w_s,
|
||||
b_group_zeros=maybe_convert_zeropoints(w_zp, w_s),
|
||||
b_group_size=group_size,
|
||||
)
|
||||
|
||||
output_ref = torch.matmul(a, w_ref)
|
||||
|
||||
# Run the model with a cuda graph
|
||||
stream = torch.cuda.Stream()
|
||||
with torch.cuda.stream(stream):
|
||||
g = torch.cuda.CUDAGraph()
|
||||
with torch.cuda.graph(g):
|
||||
output = model(a)
|
||||
output.zero_()
|
||||
g.replay()
|
||||
|
||||
# Relax atol as our reduction dim becomes larger (more rounding error)
|
||||
# Relax atol when we have zeropoints since the way machete applies
|
||||
# zeropoints (after scales) causes noise around 0
|
||||
atol = 1 if zero_points else min(5e-2 * math.sqrt(k), 1)
|
||||
torch.testing.assert_close(output, output_ref, rtol=1e-1, atol=atol)
|
||||
812
tests/kernels/quantization/test_marlin_gemm.py
Normal file
812
tests/kernels/quantization/test_marlin_gemm.py
Normal file
@@ -0,0 +1,812 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Tests for the marlin kernel.
|
||||
|
||||
Run `pytest tests/kernels/quantization/test_marlin_gemm.py`.
|
||||
"""
|
||||
|
||||
import itertools
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from tests.kernels.utils import DEFAULT_OPCHECK_TEST_UTILS, opcheck
|
||||
from tests.quantization.utils import is_quant_method_supported
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
|
||||
GPTQ_MARLIN_24_MAX_PARALLEL,
|
||||
GPTQ_MARLIN_24_MIN_THREAD_N,
|
||||
GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES,
|
||||
GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.int8_utils import (
|
||||
per_token_quant_int8,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
|
||||
marlin_make_empty_g_idx,
|
||||
marlin_make_workspace_new,
|
||||
marlin_permute_bias,
|
||||
marlin_permute_scales,
|
||||
query_marlin_supported_quant_types,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
|
||||
rand_marlin_weight_mxfp4_like,
|
||||
rand_marlin_weight_nvfp4_like,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
|
||||
marlin_quant_fp8_torch,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
|
||||
MarlinWorkspace,
|
||||
awq_marlin_quantize,
|
||||
get_weight_perm,
|
||||
marlin_quantize,
|
||||
marlin_weights,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils_test_24 import (
|
||||
marlin_24_quantize,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
awq_pack,
|
||||
gptq_pack,
|
||||
gptq_quantize_weights,
|
||||
quantize_weights,
|
||||
sort_weights,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.scalar_type import scalar_types
|
||||
|
||||
if current_platform.is_rocm():
|
||||
pytest.skip(
|
||||
"These tests require gptq_marlin_repack,"
|
||||
"marlin_int4_fp8_preprocess, gptq_marlin_24_gemm,"
|
||||
"or gptq_marlin_gemm which are not supported on ROCm.",
|
||||
allow_module_level=True,
|
||||
)
|
||||
|
||||
ACT_ORDER_OPTS = [False, True]
|
||||
K_FULL_OPTS = [False, True]
|
||||
USE_ATOMIC_ADD_OPTS = [False, True]
|
||||
USE_FP32_REDUCE_OPTS = [True]
|
||||
|
||||
MARLIN_K_CHUNKS = [128]
|
||||
MARLIN_N_CHUNKS = [64, 256]
|
||||
|
||||
MARLIN_24_K_CHUNKS = [128]
|
||||
MARLIN_24_N_CHUNKS = [512]
|
||||
|
||||
HQQ_SUPPORTED_GROUP_SIZES = [64]
|
||||
|
||||
MARLIN_REPACK_NK_FACTORS = [
|
||||
(4, 8),
|
||||
(7, 5),
|
||||
(13, 11),
|
||||
]
|
||||
|
||||
MNK_FACTORS = [
|
||||
(1, 1, 1),
|
||||
(1, 4, 8),
|
||||
(26, 37, 13),
|
||||
(257, 13, 11),
|
||||
]
|
||||
|
||||
DTYPES = [torch.float16, torch.bfloat16]
|
||||
|
||||
DENSE_MARLIN_QUANT_TEST_CONFIGS = [
|
||||
# AWQ-INT4
|
||||
{"b_type": scalar_types.uint4, "group_blocks": [-1, 2, 4, 8]},
|
||||
# GPTQ-INT4
|
||||
{
|
||||
"b_type": scalar_types.uint4b8,
|
||||
"support_act_order": True,
|
||||
"group_blocks": [-1, 2, 4, 8],
|
||||
},
|
||||
# GPTQ-INT8
|
||||
{
|
||||
"b_type": scalar_types.uint8b128,
|
||||
"support_act_order": True,
|
||||
"group_blocks": [-1, 2, 4, 8],
|
||||
},
|
||||
# FP8
|
||||
{"b_type": scalar_types.float8_e4m3fn, "group_blocks": [-1, 8]},
|
||||
# NVFP4
|
||||
{"b_type": scalar_types.float4_e2m1f, "group_blocks": [1]},
|
||||
# MXFP4
|
||||
{
|
||||
"a_type": [scalar_types.bfloat16],
|
||||
"b_type": scalar_types.float4_e2m1f,
|
||||
"group_blocks": [2],
|
||||
},
|
||||
# AWQ-INT4 with INT8 activation
|
||||
{
|
||||
"a_type": [scalar_types.int8],
|
||||
"b_type": scalar_types.uint4,
|
||||
"group_blocks": [-1, 2, 4, 8],
|
||||
},
|
||||
# GPTQ-INT4 with INT8 activation
|
||||
{
|
||||
"a_type": [scalar_types.int8],
|
||||
"b_type": scalar_types.uint4b8,
|
||||
"group_blocks": [-1, 2, 4, 8],
|
||||
},
|
||||
# GPTQ-INT4 with FP8 activation
|
||||
{
|
||||
"a_type": [scalar_types.float8_e4m3fn],
|
||||
"b_type": scalar_types.uint4b8,
|
||||
"group_blocks": [-1, 2, 4, 8],
|
||||
},
|
||||
# AWQ-INT4 with FP8 activation
|
||||
{
|
||||
"a_type": [scalar_types.float8_e4m3fn],
|
||||
"b_type": scalar_types.uint4,
|
||||
"group_blocks": [-1, 2, 4, 8],
|
||||
},
|
||||
# MXFP4 with FP8 activation
|
||||
{
|
||||
"a_type": [scalar_types.float8_e4m3fn],
|
||||
"b_type": scalar_types.float4_e2m1f,
|
||||
"c_type": [scalar_types.bfloat16],
|
||||
"group_blocks": [2],
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
def compute_max_diff(output, output_ref):
|
||||
return torch.mean(torch.abs(output - output_ref)) / torch.mean(
|
||||
torch.abs(output_ref)
|
||||
)
|
||||
|
||||
|
||||
def rand_data(shape, dtype=torch.float16):
|
||||
return torch.randn(shape, dtype=dtype, device="cuda")
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not is_quant_method_supported("gptq_marlin"),
|
||||
reason="Marlin is not supported on this GPU type.",
|
||||
)
|
||||
def test_marlin_int4_fp8_preprocess_without_zp():
|
||||
qweight_unpacked = torch.randint(
|
||||
0, 16, size=(2048, 2048), dtype=torch.int32, device="cuda"
|
||||
)
|
||||
qweight_packed = qweight_unpacked[:, ::2] * 16 + qweight_unpacked[:, 1::2]
|
||||
qweight_packed = qweight_packed.to(torch.int8).view(torch.int32)
|
||||
|
||||
cuda_res = ops.marlin_int4_fp8_preprocess(qweight_packed)
|
||||
|
||||
torch_res = torch.where(
|
||||
qweight_unpacked >= 8, qweight_unpacked - 8, 15 - qweight_unpacked
|
||||
)
|
||||
torch_res = torch_res[:, ::2] * 16 + torch_res[:, 1::2]
|
||||
torch_res = torch_res.to(torch.int8).view(torch.int32)
|
||||
|
||||
assert (cuda_res == torch_res).all()
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not is_quant_method_supported("gptq_marlin"),
|
||||
reason="Marlin is not supported on this GPU type.",
|
||||
)
|
||||
def test_marlin_int4_fp8_preprocess_awq():
|
||||
group_size = 128
|
||||
|
||||
qweight_unpacked = torch.randint(
|
||||
0, 16, size=(2048, 2048), dtype=torch.int32, device="cuda"
|
||||
)
|
||||
qzeros_unpacked = torch.randint(
|
||||
0, 16, size=(2048 // group_size, 2048), dtype=torch.int32, device="cuda"
|
||||
)
|
||||
|
||||
qweight_packed = qweight_unpacked[:, ::2] * 16 + qweight_unpacked[:, 1::2]
|
||||
qweight_packed = qweight_packed.to(torch.int8).view(torch.int32)
|
||||
qzeros_packed = qzeros_unpacked[:, ::2] * 16 + qzeros_unpacked[:, 1::2]
|
||||
qzeros_packed = qzeros_packed.to(torch.int8).view(torch.int32)
|
||||
|
||||
cuda_res = ops.marlin_int4_fp8_preprocess(qweight_packed, qzeros_packed)
|
||||
|
||||
repeated_zp = qzeros_unpacked.repeat_interleave(group_size, 0)
|
||||
torch_res = qweight_unpacked - repeated_zp
|
||||
torch_res[torch_res < 0] = 15 - qweight_unpacked[torch_res < 0]
|
||||
torch_res = torch_res[:, ::2] * 16 + torch_res[:, 1::2]
|
||||
torch_res = torch_res.to(torch.int8).view(torch.int32)
|
||||
|
||||
assert (cuda_res == torch_res).all()
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not is_quant_method_supported("gptq_marlin"),
|
||||
reason="Marlin is not supported on this GPU type.",
|
||||
)
|
||||
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
|
||||
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
|
||||
@pytest.mark.parametrize("quant_type", query_marlin_supported_quant_types(False, False))
|
||||
@pytest.mark.parametrize("act_order", ACT_ORDER_OPTS)
|
||||
@pytest.mark.parametrize("is_a_8bit", [True, False])
|
||||
@pytest.mark.parametrize("nk_factors", MARLIN_REPACK_NK_FACTORS)
|
||||
def test_gptq_marlin_repack(
|
||||
k_chunk, n_chunk, quant_type, act_order, is_a_8bit, nk_factors
|
||||
):
|
||||
n_factor, k_factor = nk_factors
|
||||
|
||||
size_k = k_chunk * k_factor
|
||||
size_n = n_chunk * n_factor
|
||||
group_size = 128
|
||||
|
||||
# Filter act_order
|
||||
if act_order:
|
||||
if group_size == -1:
|
||||
return
|
||||
if group_size == size_k:
|
||||
return
|
||||
if is_a_8bit:
|
||||
return
|
||||
|
||||
# Normalize group_size
|
||||
if group_size == -1:
|
||||
group_size = size_k
|
||||
assert group_size <= size_k
|
||||
|
||||
# Create input
|
||||
b_weight = rand_data((size_k, size_n))
|
||||
|
||||
# Quantize (and apply act_order if provided)
|
||||
w_ref, q_w, s, g_idx, rand_perm = gptq_quantize_weights(
|
||||
b_weight, quant_type, group_size, act_order
|
||||
)
|
||||
|
||||
# Pack to GPTQ format
|
||||
q_w_gptq = gptq_pack(q_w, quant_type.size_bits, size_k, size_n)
|
||||
|
||||
# For act_order, sort the "weights" and "g_idx" so that group ids are
|
||||
# increasing
|
||||
sort_indices = torch.empty(0, dtype=torch.int, device=b_weight.device)
|
||||
if act_order:
|
||||
q_w, g_idx, sort_indices = sort_weights(q_w, g_idx)
|
||||
|
||||
# Pack to Marlin format
|
||||
weight_perm = get_weight_perm(quant_type.size_bits, is_a_8bit)
|
||||
marlin_q_w_1 = marlin_weights(
|
||||
q_w, size_k, size_n, quant_type.size_bits, weight_perm, is_a_8bit
|
||||
)
|
||||
|
||||
opcheck(
|
||||
torch.ops._C.gptq_marlin_repack,
|
||||
(q_w_gptq, sort_indices, size_k, size_n, quant_type.size_bits, is_a_8bit),
|
||||
)
|
||||
|
||||
# Run Marlin repack GPU kernel
|
||||
marlin_q_w_2 = ops.gptq_marlin_repack(
|
||||
q_w_gptq, sort_indices, size_k, size_n, quant_type.size_bits, is_a_8bit
|
||||
)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
torch.testing.assert_close(marlin_q_w_1, marlin_q_w_2)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not is_quant_method_supported("gptq_marlin"),
|
||||
reason="Marlin is not supported on this GPU type.",
|
||||
)
|
||||
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
|
||||
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
|
||||
@pytest.mark.parametrize("quant_type", query_marlin_supported_quant_types(True))
|
||||
@pytest.mark.parametrize("is_a_8bit", [True, False])
|
||||
@pytest.mark.parametrize("nk_factors", MARLIN_REPACK_NK_FACTORS)
|
||||
def test_awq_marlin_repack(k_chunk, n_chunk, quant_type, is_a_8bit, nk_factors):
|
||||
n_factor, k_factor = nk_factors
|
||||
|
||||
size_k = k_chunk * k_factor
|
||||
size_n = n_chunk * n_factor
|
||||
|
||||
group_size = 128
|
||||
|
||||
# Create input
|
||||
b_weight = rand_data((size_k, size_n))
|
||||
|
||||
# Quantize
|
||||
w_ref, q_w, s, zp = quantize_weights(
|
||||
b_weight, quant_type, group_size, zero_points=True
|
||||
)
|
||||
|
||||
# Pack to AWQ format
|
||||
q_w_awq = awq_pack(q_w, quant_type.size_bits, size_k, size_n)
|
||||
|
||||
# Pack to Marlin format
|
||||
weight_perm = get_weight_perm(quant_type.size_bits, is_a_8bit)
|
||||
marlin_q_w_1 = marlin_weights(
|
||||
q_w, size_k, size_n, quant_type.size_bits, weight_perm, is_a_8bit
|
||||
)
|
||||
|
||||
opcheck(
|
||||
torch.ops._C.awq_marlin_repack,
|
||||
(q_w_awq, size_k, size_n, quant_type.size_bits, is_a_8bit),
|
||||
)
|
||||
|
||||
# Run Marlin repack GPU kernel
|
||||
marlin_q_w_2 = ops.awq_marlin_repack(
|
||||
q_w_awq, size_k, size_n, quant_type.size_bits, is_a_8bit
|
||||
)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
torch.testing.assert_close(marlin_q_w_1, marlin_q_w_2)
|
||||
|
||||
|
||||
def marlin_generate_valid_test_cases():
|
||||
all_combinations = itertools.product(
|
||||
DENSE_MARLIN_QUANT_TEST_CONFIGS,
|
||||
MNK_FACTORS,
|
||||
MARLIN_N_CHUNKS,
|
||||
MARLIN_K_CHUNKS,
|
||||
ACT_ORDER_OPTS,
|
||||
K_FULL_OPTS,
|
||||
USE_ATOMIC_ADD_OPTS,
|
||||
USE_FP32_REDUCE_OPTS,
|
||||
)
|
||||
|
||||
def is_invalid(
|
||||
a_type,
|
||||
b_type,
|
||||
c_type,
|
||||
group_blocks,
|
||||
size_m,
|
||||
size_n,
|
||||
size_k,
|
||||
act_order,
|
||||
is_k_full,
|
||||
use_atomic_add,
|
||||
use_fp32_reduce,
|
||||
):
|
||||
if use_atomic_add:
|
||||
if use_fp32_reduce:
|
||||
return False
|
||||
if (
|
||||
c_type == scalar_types.bfloat16
|
||||
and torch.cuda.get_device_capability()[0] < 9
|
||||
):
|
||||
return False
|
||||
|
||||
group_size = group_blocks if group_blocks <= 0 else group_blocks * 16
|
||||
if group_size > 0 and size_k % group_size != 0:
|
||||
return False
|
||||
|
||||
if act_order and group_size in [-1, size_k]:
|
||||
return False
|
||||
if group_size == size_k:
|
||||
return False
|
||||
if not act_order and is_k_full:
|
||||
return False
|
||||
|
||||
return a_type.size_bits < 16 or a_type is c_type
|
||||
|
||||
cases = []
|
||||
for case in all_combinations:
|
||||
quant_test_config, mnk_factors, n_chunk, k_chunk, act_order, *_ = case
|
||||
size_m = mnk_factors[0]
|
||||
size_n = mnk_factors[1] * n_chunk
|
||||
size_k = mnk_factors[2] * k_chunk
|
||||
|
||||
if act_order and not quant_test_config.get("support_act_order", False):
|
||||
continue
|
||||
|
||||
f16_types = [scalar_types.float16, scalar_types.bfloat16]
|
||||
inner_combinations = itertools.product(
|
||||
quant_test_config.get("a_type", f16_types),
|
||||
[quant_test_config["b_type"]],
|
||||
quant_test_config.get("c_type", f16_types),
|
||||
quant_test_config["group_blocks"],
|
||||
)
|
||||
|
||||
for sub_case in inner_combinations:
|
||||
if (
|
||||
sub_case[0] == scalar_types.float8_e4m3fn
|
||||
and current_platform.get_device_capability() not in [89, 120]
|
||||
):
|
||||
continue
|
||||
args = sub_case + (size_m, size_n, size_k) + case[4:]
|
||||
if is_invalid(*args):
|
||||
cases.append(args)
|
||||
return cases
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not is_quant_method_supported("gptq_marlin"),
|
||||
reason="Marlin is not supported on this GPU type.",
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
(
|
||||
"a_type, b_type, c_type, group_blocks,"
|
||||
"size_m, size_n, size_k, act_order, is_k_full,"
|
||||
"use_atomic_add, use_fp32_reduce"
|
||||
),
|
||||
marlin_generate_valid_test_cases(),
|
||||
)
|
||||
def test_gptq_marlin_gemm(
|
||||
a_type,
|
||||
b_type,
|
||||
c_type,
|
||||
group_blocks,
|
||||
size_m,
|
||||
size_n,
|
||||
size_k,
|
||||
act_order,
|
||||
is_k_full,
|
||||
use_atomic_add,
|
||||
use_fp32_reduce,
|
||||
):
|
||||
has_zp = b_type in [scalar_types.uint4, scalar_types.uint8]
|
||||
|
||||
group_size = group_blocks if group_blocks <= 0 else group_blocks * 16
|
||||
|
||||
if c_type == scalar_types.float16:
|
||||
dtype = torch.float16
|
||||
elif c_type == scalar_types.bfloat16:
|
||||
dtype = torch.bfloat16
|
||||
else:
|
||||
raise RuntimeError("unsupported c_type")
|
||||
|
||||
if a_type == scalar_types.int8:
|
||||
a_dtype = torch.int8
|
||||
elif a_type == scalar_types.float8_e4m3fn:
|
||||
a_dtype = torch.float8_e4m3fn
|
||||
else:
|
||||
a_dtype = dtype
|
||||
|
||||
a_input = rand_data((size_m, size_k), dtype=dtype)
|
||||
b_weight = rand_data((size_k, size_n), dtype=dtype)
|
||||
|
||||
if b_type == scalar_types.float4_e2m1f:
|
||||
if group_size == 16:
|
||||
w_ref, marlin_q_w, marlin_s, marlin_s2 = rand_marlin_weight_nvfp4_like(
|
||||
b_weight.T, group_size, input_dtype=a_dtype
|
||||
)
|
||||
else:
|
||||
w_ref, marlin_q_w, marlin_s = rand_marlin_weight_mxfp4_like(
|
||||
b_weight.T, group_size, input_dtype=a_dtype
|
||||
)
|
||||
marlin_s2 = None
|
||||
|
||||
g_idx = None
|
||||
sort_indices = None
|
||||
marlin_zp = None
|
||||
elif b_type == scalar_types.float8_e4m3fn:
|
||||
w_ref, marlin_q_w, marlin_s = marlin_quant_fp8_torch(
|
||||
b_weight.T, group_size, input_dtype=a_dtype
|
||||
)
|
||||
g_idx = None
|
||||
sort_indices = None
|
||||
marlin_zp = None
|
||||
marlin_s2 = None
|
||||
elif has_zp:
|
||||
w_ref, marlin_q_w, marlin_s, marlin_zp = awq_marlin_quantize(
|
||||
b_weight, b_type, group_size, input_dtype=a_dtype
|
||||
)
|
||||
g_idx = None
|
||||
sort_indices = None
|
||||
marlin_s2 = None
|
||||
else:
|
||||
w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, _ = marlin_quantize(
|
||||
b_weight, b_type, group_size, act_order, input_dtype=a_dtype
|
||||
)
|
||||
|
||||
marlin_zp = None
|
||||
marlin_s2 = None
|
||||
|
||||
workspace = marlin_make_workspace_new(w_ref.device)
|
||||
|
||||
if a_type == scalar_types.int8:
|
||||
a_input, a_scales = per_token_quant_int8(a_input)
|
||||
a_input_ref = a_input.to(a_scales.dtype) * a_scales.view(-1, 1)
|
||||
a_input_ref = a_input_ref.to(dtype)
|
||||
|
||||
if group_size != -1:
|
||||
a_scales = a_scales / 4096 * marlin_s.max()
|
||||
a_scales = a_scales.float()
|
||||
marlin_s = marlin_s / marlin_s.max() * 4096
|
||||
marlin_s = marlin_s.round().to(torch.int16).view(dtype)
|
||||
elif a_type == scalar_types.float8_e4m3fn:
|
||||
a_input, a_scales = ops.scaled_fp8_quant(a_input, use_per_token_if_dynamic=True)
|
||||
a_input_ref = a_input.to(a_scales.dtype) * a_scales.view(-1, 1)
|
||||
a_input_ref = a_input_ref.to(dtype)
|
||||
else:
|
||||
assert a_type.size_bits == 16
|
||||
a_input_ref = a_input
|
||||
a_scales = None
|
||||
|
||||
output = torch.empty((size_m, size_n), dtype=dtype, device=a_input.device)
|
||||
|
||||
output = ops.gptq_marlin_gemm(
|
||||
a_input,
|
||||
output,
|
||||
marlin_q_w,
|
||||
None,
|
||||
marlin_s,
|
||||
a_scales,
|
||||
marlin_s2,
|
||||
marlin_zp,
|
||||
g_idx,
|
||||
sort_indices,
|
||||
workspace,
|
||||
b_type,
|
||||
a_input.shape[0],
|
||||
b_weight.shape[1],
|
||||
a_input.shape[1],
|
||||
is_k_full=is_k_full,
|
||||
use_atomic_add=use_atomic_add,
|
||||
use_fp32_reduce=use_fp32_reduce,
|
||||
is_zp_float=False,
|
||||
)
|
||||
output_ref = torch.matmul(a_input_ref, w_ref)
|
||||
|
||||
max_diff = compute_max_diff(output, output_ref)
|
||||
assert max_diff < 0.04
|
||||
|
||||
|
||||
# TODO: find better way to test this?
|
||||
@torch.compile(fullgraph=True)
|
||||
def marlin_24_gemm_tester(
|
||||
a_input,
|
||||
marlin_24_q_w_comp,
|
||||
marlin_24_meta,
|
||||
marlin_24_s,
|
||||
scratch,
|
||||
quant_type,
|
||||
size_m,
|
||||
size_n,
|
||||
size_k,
|
||||
):
|
||||
return ops.gptq_marlin_24_gemm(
|
||||
a_input,
|
||||
marlin_24_q_w_comp,
|
||||
marlin_24_meta,
|
||||
marlin_24_s,
|
||||
scratch,
|
||||
quant_type,
|
||||
size_m,
|
||||
size_n,
|
||||
size_k,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not is_quant_method_supported("gptq_marlin"),
|
||||
reason="Marlin is not supported on this GPU type.",
|
||||
)
|
||||
@pytest.mark.parametrize("k_chunk", MARLIN_24_K_CHUNKS)
|
||||
@pytest.mark.parametrize("n_chunk", MARLIN_24_N_CHUNKS)
|
||||
@pytest.mark.parametrize("quant_type", GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES)
|
||||
@pytest.mark.parametrize("group_size", GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES)
|
||||
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
|
||||
def test_gptq_marlin_24_gemm(k_chunk, n_chunk, quant_type, group_size, mnk_factors):
|
||||
m_factor, n_factor, k_factor = mnk_factors
|
||||
|
||||
size_m = m_factor
|
||||
size_k = k_chunk * k_factor
|
||||
size_n = n_chunk * n_factor
|
||||
|
||||
a_input = rand_data((size_m, size_k))
|
||||
b_weight = rand_data((size_k, size_n))
|
||||
|
||||
(w_24_ref, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s) = marlin_24_quantize(
|
||||
b_weight, quant_type, group_size
|
||||
)
|
||||
|
||||
workspace_24 = MarlinWorkspace(
|
||||
size_n, GPTQ_MARLIN_24_MIN_THREAD_N, GPTQ_MARLIN_24_MAX_PARALLEL
|
||||
)
|
||||
|
||||
output_ref = torch.matmul(a_input, w_24_ref)
|
||||
|
||||
opcheck(
|
||||
torch.ops._C.gptq_marlin_24_gemm,
|
||||
(
|
||||
a_input,
|
||||
marlin_24_q_w_comp,
|
||||
marlin_24_meta,
|
||||
marlin_24_s,
|
||||
workspace_24.scratch,
|
||||
quant_type.id,
|
||||
a_input.shape[0],
|
||||
b_weight.shape[1],
|
||||
a_input.shape[1],
|
||||
),
|
||||
test_utils=DEFAULT_OPCHECK_TEST_UTILS,
|
||||
)
|
||||
|
||||
output = marlin_24_gemm_tester(
|
||||
a_input,
|
||||
marlin_24_q_w_comp,
|
||||
marlin_24_meta,
|
||||
marlin_24_s,
|
||||
workspace_24.scratch,
|
||||
quant_type,
|
||||
a_input.shape[0],
|
||||
b_weight.shape[1],
|
||||
a_input.shape[1],
|
||||
)
|
||||
|
||||
torch.cuda.synchronize()
|
||||
|
||||
max_diff = compute_max_diff(output, output_ref)
|
||||
|
||||
assert max_diff < 0.04
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not is_quant_method_supported("gptq_marlin"),
|
||||
reason="Marlin is not supported on this GPU type.",
|
||||
)
|
||||
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
|
||||
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
|
||||
@pytest.mark.parametrize("group_size", HQQ_SUPPORTED_GROUP_SIZES)
|
||||
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
|
||||
@pytest.mark.parametrize("use_fp32_reduce", USE_FP32_REDUCE_OPTS)
|
||||
def test_hqq_marlin_gemm(
|
||||
k_chunk,
|
||||
n_chunk,
|
||||
group_size,
|
||||
mnk_factors,
|
||||
use_fp32_reduce,
|
||||
):
|
||||
m_factor, n_factor, k_factor = mnk_factors
|
||||
|
||||
size_m = m_factor
|
||||
size_k = k_chunk * k_factor
|
||||
size_n = n_chunk * n_factor
|
||||
|
||||
quant_type = scalar_types.uint4
|
||||
|
||||
a_input = rand_data((size_m, size_k))
|
||||
dev = a_input.device
|
||||
|
||||
b_weight = torch.randint(0, 10, (size_n, size_k), dtype=torch.uint8, device=dev)
|
||||
scale = rand_data((size_n, size_k // group_size))
|
||||
zero = rand_data((size_n, size_k // group_size))
|
||||
|
||||
gptq_w_q = gptq_pack(b_weight.transpose(1, 0), 4, size_k, size_n)
|
||||
|
||||
sort_indices = torch.empty(0, dtype=torch.int, device=dev)
|
||||
marlin_w_q = ops.gptq_marlin_repack(gptq_w_q, sort_indices, size_k, size_n, 4).to(
|
||||
dev
|
||||
)
|
||||
marlin_s = marlin_permute_scales(
|
||||
scale.transpose(1, 0), size_k, size_n, group_size
|
||||
).to(dev)
|
||||
marlin_zp = marlin_permute_scales(
|
||||
zero.transpose(1, 0), size_k, size_n, group_size
|
||||
).to(dev)
|
||||
|
||||
g_idx = marlin_make_empty_g_idx(dev)
|
||||
g_idx_sort_indices = marlin_make_empty_g_idx(dev)
|
||||
|
||||
workspace = marlin_make_workspace_new(b_weight.device)
|
||||
|
||||
output = ops.gptq_marlin_gemm(
|
||||
a_input,
|
||||
None,
|
||||
marlin_w_q,
|
||||
None,
|
||||
marlin_s,
|
||||
None,
|
||||
None,
|
||||
marlin_zp,
|
||||
g_idx,
|
||||
g_idx_sort_indices,
|
||||
workspace,
|
||||
quant_type,
|
||||
a_input.shape[0],
|
||||
b_weight.shape[0],
|
||||
a_input.shape[1],
|
||||
is_k_full=True,
|
||||
use_fp32_reduce=use_fp32_reduce,
|
||||
is_zp_float=True,
|
||||
)
|
||||
|
||||
b_flat = b_weight.reshape(-1, group_size)
|
||||
zp_flat = zero.reshape(-1, 1)
|
||||
s_flat = scale.reshape(-1, 1)
|
||||
dequant = (b_flat - zp_flat) * s_flat
|
||||
|
||||
output_ref = torch.matmul(a_input, dequant.reshape(b_weight.shape).transpose(1, 0))
|
||||
|
||||
torch.cuda.synchronize()
|
||||
|
||||
max_diff = compute_max_diff(output, output_ref)
|
||||
|
||||
assert max_diff < 0.04
|
||||
|
||||
|
||||
def test_marlin_gemm_subset_input():
|
||||
quant_type = scalar_types.uint4b8
|
||||
group_size = 128
|
||||
|
||||
size_m, size_k, size_n = 32, 1024, 2048
|
||||
big_m = size_m * 2
|
||||
big_k = size_k * 2
|
||||
|
||||
a_input = rand_data((big_m, big_k))[8 : size_m + 8, 8 : size_k + 8]
|
||||
b_weight = rand_data((size_k, size_n))
|
||||
|
||||
w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, _ = marlin_quantize(
|
||||
b_weight, quant_type, group_size, False
|
||||
)
|
||||
|
||||
marlin_zp = marlin_make_empty_g_idx(marlin_s.device)
|
||||
workspace = marlin_make_workspace_new(a_input.device)
|
||||
|
||||
output = ops.gptq_marlin_gemm(
|
||||
a_input,
|
||||
None,
|
||||
marlin_q_w,
|
||||
None,
|
||||
marlin_s,
|
||||
None,
|
||||
None,
|
||||
marlin_zp,
|
||||
g_idx,
|
||||
sort_indices,
|
||||
workspace,
|
||||
quant_type,
|
||||
a_input.shape[0],
|
||||
b_weight.shape[1],
|
||||
a_input.shape[1],
|
||||
is_k_full=True,
|
||||
use_atomic_add=False,
|
||||
use_fp32_reduce=True,
|
||||
is_zp_float=False,
|
||||
)
|
||||
output_ref = torch.matmul(a_input, w_ref)
|
||||
|
||||
torch.cuda.synchronize()
|
||||
|
||||
max_diff = compute_max_diff(output, output_ref)
|
||||
|
||||
assert max_diff < 0.04
|
||||
|
||||
|
||||
@pytest.mark.parametrize("size_m", [1, 256])
|
||||
def test_marlin_gemm_with_bias(size_m):
|
||||
quant_type = scalar_types.uint4b8
|
||||
group_size = 128
|
||||
|
||||
size_k, size_n = 1024, 2048
|
||||
a_input = rand_data((size_m, size_k))
|
||||
b_weight = rand_data((size_k, size_n))
|
||||
b_bias = rand_data((size_n,)) * 10
|
||||
|
||||
marlin_bias = marlin_permute_bias(b_bias)
|
||||
|
||||
w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, _ = marlin_quantize(
|
||||
b_weight, quant_type, group_size, False
|
||||
)
|
||||
|
||||
marlin_zp = marlin_make_empty_g_idx(marlin_s.device)
|
||||
workspace = marlin_make_workspace_new(a_input.device)
|
||||
|
||||
output = ops.gptq_marlin_gemm(
|
||||
a_input,
|
||||
None,
|
||||
marlin_q_w,
|
||||
marlin_bias,
|
||||
marlin_s,
|
||||
None,
|
||||
None,
|
||||
marlin_zp,
|
||||
g_idx,
|
||||
sort_indices,
|
||||
workspace,
|
||||
quant_type,
|
||||
a_input.shape[0],
|
||||
b_weight.shape[1],
|
||||
a_input.shape[1],
|
||||
is_k_full=True,
|
||||
use_atomic_add=False,
|
||||
use_fp32_reduce=True,
|
||||
is_zp_float=False,
|
||||
)
|
||||
output_ref = torch.matmul(a_input, w_ref) + b_bias.view(1, -1)
|
||||
|
||||
torch.cuda.synchronize()
|
||||
|
||||
max_diff = compute_max_diff(output, output_ref)
|
||||
|
||||
assert max_diff < 0.04
|
||||
303
tests/kernels/quantization/test_mxfp4_qutlass.py
Normal file
303
tests/kernels/quantization/test_mxfp4_qutlass.py
Normal file
@@ -0,0 +1,303 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
#
|
||||
# Copyright (C) 2025 Roberto L. Castro (Roberto.LopezCastro@ist.ac.at).
|
||||
# All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
from compressed_tensors.transform.utils.hadamard import deterministic_hadamard_matrix
|
||||
|
||||
from vllm._custom_ops import fusedQuantizeMx, matmul_mxf4_bf16_tn
|
||||
from vllm.model_executor.layers.quantization.qutlass_utils import to_blocked
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
if not torch.cuda.is_available():
|
||||
pytest.skip("CUDA required for these tests.", allow_module_level=True)
|
||||
|
||||
if not (
|
||||
current_platform.has_device_capability(100)
|
||||
or current_platform.has_device_capability(120)
|
||||
):
|
||||
pytest.skip(
|
||||
reason="Tests require compute capability 10.0 (100) or 12.0 (120).",
|
||||
allow_module_level=True,
|
||||
)
|
||||
|
||||
|
||||
# ----- Helpers -----
|
||||
def get_hadamard_matrix(group_size: int, dtype: torch.dtype, device: torch.device):
|
||||
return (
|
||||
deterministic_hadamard_matrix(group_size, dtype=dtype, device=device)
|
||||
* group_size**-0.5
|
||||
)
|
||||
|
||||
|
||||
def _rtne_fp4(x: torch.Tensor):
|
||||
device = x.device
|
||||
grid = torch.tensor(
|
||||
[
|
||||
-6.0,
|
||||
-4.0,
|
||||
-3.0,
|
||||
-2.0,
|
||||
-1.5,
|
||||
-1.0,
|
||||
-0.5,
|
||||
-0.0,
|
||||
0.0,
|
||||
0.5,
|
||||
1.0,
|
||||
1.5,
|
||||
2.0,
|
||||
3.0,
|
||||
4.0,
|
||||
6.0,
|
||||
],
|
||||
dtype=x.dtype,
|
||||
device=x.device,
|
||||
)
|
||||
grid_int = torch.tensor(
|
||||
[-1, -2, -3, -4, -5, -6, -7, -8, 0, 1, 2, 3, 4, 5, 6, 7],
|
||||
dtype=torch.uint8,
|
||||
device=device,
|
||||
)
|
||||
inds = torch.bucketize(x, grid)
|
||||
lo, hi = (inds - 1).clamp(min=0, max=15), inds.clamp(min=0, max=15)
|
||||
g_lo, g_hi = grid[lo], grid[hi]
|
||||
pick_hi = (g_hi - x < x - g_lo) | (g_hi - x == x - g_lo) & (grid_int[hi] % 2 == 0)
|
||||
y = torch.where(pick_hi, g_hi, g_lo)
|
||||
y_int = torch.where(pick_hi, grid_int[hi], grid_int[lo])
|
||||
y_int_packed = (y_int[..., 1::2] & 0xF) << 4 | y_int[..., ::2] & 0xF
|
||||
return y, y_int_packed
|
||||
|
||||
|
||||
def _dq_fp4(x_e2m1: torch.Tensor, x_e8m0: torch.Tensor, alpha: float):
|
||||
device = x_e2m1.device
|
||||
|
||||
x_e2m1_i32 = x_e2m1.view(dtype=torch.uint8).to(dtype=torch.int32)
|
||||
x_e2m1_unpacked = torch.stack(
|
||||
[x_e2m1_i32 & 0xF, (x_e2m1_i32 >> 4) & 0xF], dim=-1
|
||||
).flatten(start_dim=-2)
|
||||
|
||||
grid_dq = torch.tensor(
|
||||
[
|
||||
0.0,
|
||||
0.5,
|
||||
1.0,
|
||||
1.5,
|
||||
2.0,
|
||||
3.0,
|
||||
4.0,
|
||||
6.0,
|
||||
-0.0,
|
||||
-0.5,
|
||||
-1.0,
|
||||
-1.5,
|
||||
-2.0,
|
||||
-3.0,
|
||||
-4.0,
|
||||
-6.0,
|
||||
],
|
||||
dtype=torch.float64,
|
||||
device=device,
|
||||
)
|
||||
x_fp4_dq = grid_dq[x_e2m1_unpacked]
|
||||
scales_dq = x_e8m0.to(torch.float64)
|
||||
|
||||
x_dq = (x_fp4_dq.unflatten(dim=-1, sizes=(-1, 32)) * scales_dq[..., None]).flatten(
|
||||
start_dim=-2
|
||||
) / alpha
|
||||
return x_dq, x_fp4_dq, scales_dq
|
||||
|
||||
|
||||
def _unpack_mask(clip_mask: torch.Tensor) -> torch.Tensor:
|
||||
clip_mask_unpacked_dq = torch.zeros(
|
||||
*clip_mask.shape[:-1],
|
||||
clip_mask.size(-1) * 8,
|
||||
dtype=torch.bool,
|
||||
device=clip_mask.device,
|
||||
)
|
||||
for i in range(8):
|
||||
clip_mask_unpacked_dq[..., i::8] = (clip_mask >> i) & 1
|
||||
return clip_mask_unpacked_dq
|
||||
|
||||
|
||||
def _forward_quantize_ref(
|
||||
x: torch.Tensor, h: torch.Tensor, rot_size: int, quest: bool = True
|
||||
):
|
||||
device = x.device
|
||||
xh_ref64 = (
|
||||
x.unflatten(dim=-1, sizes=(-1, rot_size)).to(dtype=torch.float64)
|
||||
@ h.reshape(rot_size, rot_size).to(dtype=torch.float64)
|
||||
).flatten(start_dim=-2)
|
||||
|
||||
if quest:
|
||||
scales_ref64_ = (
|
||||
xh_ref64.unflatten(dim=-1, sizes=(-1, 32)).std(dim=-1, correction=0)
|
||||
* (2.92247856 / 6.0)
|
||||
+ 1e-8
|
||||
)
|
||||
else:
|
||||
abs_max = xh_ref64.unflatten(dim=-1, sizes=(-1, 32)).abs().amax(dim=-1)
|
||||
scales_ref64_ = abs_max + 1e-8
|
||||
|
||||
xh_e8m0_ref = scales_ref64_.log2().floor().exp2().to(dtype=torch.float8_e8m0fnu)
|
||||
scales_ref64 = xh_e8m0_ref.to(dtype=torch.float64)
|
||||
|
||||
xh_scaled_ref64 = (
|
||||
xh_ref64.unflatten(dim=-1, sizes=(-1, 32)) / scales_ref64[..., None]
|
||||
).flatten(start_dim=-2)
|
||||
if not quest:
|
||||
xh_scaled_ref64 *= 3
|
||||
|
||||
clip_mask_unpacked_ref = xh_scaled_ref64.abs() < 6.0
|
||||
clip_mask_ref = torch.zeros(
|
||||
*x.shape[:-1], x.size(-1) // 8, dtype=torch.uint8, device=device
|
||||
)
|
||||
for i in range(8):
|
||||
clip_mask_ref |= clip_mask_unpacked_ref[..., i::8].to(dtype=torch.uint8) << i
|
||||
|
||||
xh_fp4_ref, xh_e2m1_ref = _rtne_fp4(xh_scaled_ref64)
|
||||
xh_dq, xh_fp4_dq, scales_dq = _dq_fp4(
|
||||
xh_e2m1_ref, xh_e8m0_ref, alpha=1.0 if quest else 3.0
|
||||
)
|
||||
clip_mask_unpacked_dq = _unpack_mask(clip_mask_ref)
|
||||
|
||||
assert xh_fp4_dq.equal(xh_fp4_ref)
|
||||
assert scales_dq.equal(scales_ref64)
|
||||
assert clip_mask_unpacked_dq.equal(clip_mask_unpacked_ref)
|
||||
|
||||
return (
|
||||
xh_dq,
|
||||
clip_mask_unpacked_ref,
|
||||
(xh_e2m1_ref, xh_e8m0_ref, clip_mask_ref),
|
||||
)
|
||||
|
||||
|
||||
DTYPE = torch.bfloat16
|
||||
DEVICE = torch.device("cuda:0")
|
||||
|
||||
ROT_SIZES = [32, 64, 128]
|
||||
SEEDS = [0]
|
||||
BATCHES = [1, 16]
|
||||
|
||||
LLAMA_MODELS = {
|
||||
"7B": [(4096, 3 * 4096), (4096, 4096), (4096, 2 * 10752), (10752, 4096)],
|
||||
"13B": [(5120, 3 * 5120), (5120, 5120), (5120, 2 * 13568), (13568, 5120)],
|
||||
"33B": [(6656, 3 * 6656), (6656, 6656), (6656, 2 * 17664), (17664, 6656)],
|
||||
"70B": [(8192, 3 * 8192), (8192, 8192), (8192, 2 * 21760), (21760, 8192)],
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _seed_each_test():
|
||||
current_platform.seed_everything(0)
|
||||
np.random.seed(0)
|
||||
torch.random.manual_seed(0)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("rot_size", ROT_SIZES)
|
||||
@torch.inference_mode()
|
||||
def test_fused_quantization_absmax(rot_size: int):
|
||||
dtype, device = DTYPE, DEVICE
|
||||
h = get_hadamard_matrix(rot_size, dtype, device)
|
||||
x = torch.randn(2, 4096, 4096, dtype=dtype, device=device) * 25.0
|
||||
|
||||
xh_dq_ref, _, _ = _forward_quantize_ref(x, h, rot_size, quest=False)
|
||||
xh_e2m1, xh_e8m0 = fusedQuantizeMx(x, h, method="abs_max")
|
||||
xh_e8m0 = xh_e8m0.reshape(2, 4096, 4096 // 32)
|
||||
xh_dq, *_ = _dq_fp4(xh_e2m1, xh_e8m0, alpha=3.0)
|
||||
|
||||
torch.testing.assert_close(xh_dq, xh_dq_ref, rtol=0.34, atol=100)
|
||||
assert (xh_dq != xh_dq_ref).float().mean() <= 1e-4
|
||||
|
||||
m, n, k = 1, 504, 4096
|
||||
a = torch.randn(m, k, dtype=dtype, device=device) * 25.0
|
||||
b = torch.randn(n, k, dtype=dtype, device=device) * 25.0
|
||||
|
||||
a_e2m1, a_e8m0 = fusedQuantizeMx(a, h, method="abs_max")
|
||||
b_e2m1, b_e8m0 = fusedQuantizeMx(b, h, method="abs_max")
|
||||
a_dq, *_ = _dq_fp4(a_e2m1, a_e8m0[:m, :k], alpha=1.0)
|
||||
b_dq, *_ = _dq_fp4(b_e2m1, b_e8m0[:n, :k], alpha=1.0)
|
||||
out_ref = a_dq @ b_dq.transpose(-2, -1)
|
||||
|
||||
a_scale_block = to_blocked(a_e8m0, backend="triton")
|
||||
b_scale_block = to_blocked(b_e8m0, backend="triton")
|
||||
alpha = torch.tensor([1.0], device=device)
|
||||
out = matmul_mxf4_bf16_tn(a_e2m1, b_e2m1, a_scale_block, b_scale_block, alpha)
|
||||
assert out.equal(out_ref.to(dtype=out.dtype))
|
||||
|
||||
|
||||
@pytest.mark.parametrize("rot_size", ROT_SIZES)
|
||||
@torch.inference_mode()
|
||||
def test_fused_quantization_quest(rot_size: int):
|
||||
dtype, device = DTYPE, DEVICE
|
||||
h = get_hadamard_matrix(rot_size, dtype, device)
|
||||
x = torch.randn(2, 4096, 4096, dtype=dtype, device=device) * 25.0
|
||||
|
||||
xh_dq_ref, _, _ = _forward_quantize_ref(x, h, rot_size, quest=True)
|
||||
xh_e2m1, xh_e8m0 = fusedQuantizeMx(x, h, method="quest")
|
||||
xh_e8m0 = xh_e8m0.reshape(2, 4096, 4096 // 32)
|
||||
xh_dq, *_ = _dq_fp4(xh_e2m1, xh_e8m0, alpha=1.0)
|
||||
|
||||
torch.testing.assert_close(xh_dq, xh_dq_ref, rtol=0.34, atol=100)
|
||||
assert (xh_dq != xh_dq_ref).float().mean() <= 1e-4
|
||||
|
||||
m, n, k = 504, 504, 2048
|
||||
a = torch.randn(m, k, dtype=dtype, device=device) * 25.0
|
||||
b = torch.randn(n, k, dtype=dtype, device=device) * 25.0
|
||||
|
||||
a_e2m1, a_e8m0 = fusedQuantizeMx(a, h, method="quest")
|
||||
b_e2m1, b_e8m0 = fusedQuantizeMx(b, h, method="quest")
|
||||
a_dq, *_ = _dq_fp4(a_e2m1, a_e8m0[:m, :k], alpha=1.0)
|
||||
b_dq, *_ = _dq_fp4(b_e2m1, b_e8m0[:n, :k], alpha=1.0)
|
||||
out_ref = a_dq @ b_dq.transpose(-2, -1)
|
||||
|
||||
a_scale_block = to_blocked(a_e8m0, backend="triton")
|
||||
b_scale_block = to_blocked(b_e8m0, backend="triton")
|
||||
alpha = torch.tensor([1.0], device=device)
|
||||
out = matmul_mxf4_bf16_tn(a_e2m1, b_e2m1, a_scale_block, b_scale_block, alpha)
|
||||
assert out.equal(out_ref.to(dtype=out.dtype))
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", list(LLAMA_MODELS.keys()))
|
||||
@pytest.mark.parametrize("layer_idx", [0, 1, 2, 3])
|
||||
@pytest.mark.parametrize("batch", [1, 16])
|
||||
@pytest.mark.parametrize("had_size", ROT_SIZES)
|
||||
@torch.inference_mode()
|
||||
def test_llama_shapes(model: str, layer_idx: int, batch: int, had_size: int):
|
||||
dtype, device = DTYPE, DEVICE
|
||||
m = batch
|
||||
k, n = LLAMA_MODELS[model][layer_idx]
|
||||
|
||||
h = get_hadamard_matrix(had_size, dtype, device)
|
||||
|
||||
a = torch.rand(m, k, dtype=dtype, device=device) * 25.0
|
||||
b = torch.rand(n, k, dtype=dtype, device=device) * 25.0
|
||||
|
||||
a_e2m1, a_e8m0 = fusedQuantizeMx(a, h, method="quest")
|
||||
b_e2m1, b_e8m0 = fusedQuantizeMx(b, h, method="quest")
|
||||
|
||||
a_dq, *_ = _dq_fp4(a_e2m1, a_e8m0[:m, :k], alpha=1.0)
|
||||
b_dq, *_ = _dq_fp4(b_e2m1, b_e8m0[:n, :k], alpha=1.0)
|
||||
out_ref = a_dq @ b_dq.transpose(-2, -1)
|
||||
|
||||
a_scale_block = to_blocked(a_e8m0, backend="triton")
|
||||
b_scale_block = to_blocked(b_e8m0, backend="triton")
|
||||
alpha = torch.tensor([1.0], device=device)
|
||||
out = matmul_mxf4_bf16_tn(a_e2m1, b_e2m1, a_scale_block, b_scale_block, alpha)
|
||||
assert out.equal(out_ref.to(dtype=out.dtype))
|
||||
174
tests/kernels/quantization/test_nvfp4_quant.py
Normal file
174
tests/kernels/quantization/test_nvfp4_quant.py
Normal file
@@ -0,0 +1,174 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.scalar_type import scalar_types
|
||||
|
||||
if not current_platform.has_device_capability(100):
|
||||
pytest.skip(
|
||||
reason="Nvfp4 Requires compute capability of 10 or above.",
|
||||
allow_module_level=True,
|
||||
)
|
||||
|
||||
DTYPES = [torch.float16, torch.bfloat16]
|
||||
SHAPES = [(128, 64), (128, 128), (256, 64), (256, 128)]
|
||||
PAD_SHAPES = [
|
||||
(90, 64),
|
||||
(150, 64),
|
||||
(128, 48),
|
||||
(128, 80),
|
||||
(150, 80),
|
||||
(90, 48),
|
||||
(90, 128),
|
||||
(150, 128),
|
||||
(150, 48),
|
||||
(90, 80),
|
||||
]
|
||||
SEEDS = [42]
|
||||
CUDA_DEVICES = ["cuda:0"]
|
||||
|
||||
FLOAT4_E2M1_MAX = scalar_types.float4_e2m1f.max()
|
||||
FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max
|
||||
|
||||
# E2M1 to float
|
||||
# 0111 -> 6
|
||||
# 0110 -> 4
|
||||
# 0101 -> 3
|
||||
# 0100 -> 2
|
||||
# 0011 -> 1.5
|
||||
# 0010 -> 1
|
||||
# 0001 -> 0.5
|
||||
# 0000 -> 0
|
||||
E2M1_TO_FLOAT32 = [
|
||||
0.0,
|
||||
0.5,
|
||||
1.0,
|
||||
1.5,
|
||||
2.0,
|
||||
3.0,
|
||||
4.0,
|
||||
6.0,
|
||||
0.0,
|
||||
-0.5,
|
||||
-1.0,
|
||||
-1.5,
|
||||
-2.0,
|
||||
-3.0,
|
||||
-4.0,
|
||||
-6.0,
|
||||
]
|
||||
BLOCK_SIZE = 16
|
||||
|
||||
|
||||
def cast_from_fp4(x, m, n):
|
||||
# The fp4 values are packed in uint8 as [v_1st | v_2nd]
|
||||
v_2nd = x & 0xF
|
||||
v_1st = (x >> 4) & 0xF
|
||||
c = torch.stack((v_2nd, v_1st), dim=-1)
|
||||
out = torch.tensor([E2M1_TO_FLOAT32[x] for x in c.flatten()])
|
||||
out = out.reshape(m, n).to(torch.float32)
|
||||
return out
|
||||
|
||||
|
||||
def cast_to_fp4(x):
|
||||
sign = torch.sign(x)
|
||||
x = torch.abs(x)
|
||||
x[(x >= 0.0) & (x <= 0.25)] = 0.0
|
||||
x[(x > 0.25) & (x < 0.75)] = 0.5
|
||||
x[(x >= 0.75) & (x <= 1.25)] = 1.0
|
||||
x[(x > 1.25) & (x < 1.75)] = 1.5
|
||||
x[(x >= 1.75) & (x <= 2.5)] = 2.0
|
||||
x[(x > 2.5) & (x < 3.5)] = 3.0
|
||||
x[(x >= 3.5) & (x <= 5.0)] = 4.0
|
||||
x[x > 5.0] = 6.0
|
||||
return x * sign
|
||||
|
||||
|
||||
def get_reciprocal(x):
|
||||
if isinstance(x, torch.Tensor):
|
||||
return torch.where(x == 0, torch.tensor(0.0, dtype=x.dtype), 1.0 / x)
|
||||
elif isinstance(x, (float, int)):
|
||||
return 0.0 if x == 0 else 1.0 / x
|
||||
else:
|
||||
raise TypeError("Input must be a float, int, or a torch.Tensor.")
|
||||
|
||||
|
||||
def ref_nvfp4_quant(x, global_scale):
|
||||
assert global_scale.dtype == torch.float32
|
||||
assert x.ndim == 2
|
||||
m, n = x.shape
|
||||
x = torch.reshape(x, (m, n // BLOCK_SIZE, BLOCK_SIZE))
|
||||
vec_max = torch.max(torch.abs(x), dim=-1, keepdim=True)[0].to(torch.float32)
|
||||
scale = global_scale * (vec_max * get_reciprocal(FLOAT4_E2M1_MAX))
|
||||
scale = scale.to(torch.float8_e4m3fn).to(torch.float32)
|
||||
output_scale = get_reciprocal(scale * get_reciprocal(global_scale))
|
||||
|
||||
scaled_x = x.to(torch.float32) * output_scale
|
||||
clipped_x = torch.clamp(scaled_x, -6.0, 6.0).reshape(m, n)
|
||||
return cast_to_fp4(clipped_x), scale.squeeze(-1)
|
||||
|
||||
|
||||
def recover_swizzled_scales(scale, m, n):
|
||||
round_up = lambda x, y: (x + y - 1) // y * y
|
||||
rounded_m = round_up(m, 128)
|
||||
scale_n = n // BLOCK_SIZE
|
||||
rounded_n = round_up(scale_n, 4)
|
||||
# Recover the swizzled scaling factor to linear layout
|
||||
tmp = torch.reshape(scale, (1, rounded_m // 128, rounded_n // 4, 32, 4, 4))
|
||||
tmp = torch.permute(tmp, (0, 1, 4, 3, 2, 5))
|
||||
result = torch.reshape(tmp, (rounded_m, rounded_n)).to(torch.float32)
|
||||
return result[:m, :scale_n]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("shape", SHAPES)
|
||||
@pytest.mark.parametrize("seed", SEEDS)
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@torch.inference_mode()
|
||||
def test_quantize_to_fp4(
|
||||
dtype: torch.dtype,
|
||||
shape: tuple[int, int],
|
||||
seed: int,
|
||||
device: str,
|
||||
) -> None:
|
||||
current_platform.seed_everything(seed)
|
||||
torch.set_default_device(device)
|
||||
|
||||
m, n = shape
|
||||
|
||||
x = torch.randn((m, n), dtype=dtype)
|
||||
tensor_amax = torch.abs(x).max().to(torch.float32)
|
||||
global_scale = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / tensor_amax
|
||||
out_ref, scale_ref = ref_nvfp4_quant(x, global_scale)
|
||||
|
||||
out, out_scale = ops.scaled_fp4_quant(x, global_scale)
|
||||
scale_ans = recover_swizzled_scales(out_scale, m, n)
|
||||
out_ans = cast_from_fp4(out, m, n)
|
||||
|
||||
torch.testing.assert_close(out_ans, out_ref)
|
||||
torch.testing.assert_close(scale_ans, scale_ref)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("pad_shape", PAD_SHAPES)
|
||||
@torch.inference_mode()
|
||||
def test_quantize_to_fp4_padded(pad_shape: tuple[int, int]) -> None:
|
||||
dtype = torch.float16
|
||||
current_platform.seed_everything(42)
|
||||
torch.set_default_device("cuda:0")
|
||||
|
||||
m, n = pad_shape
|
||||
|
||||
x = torch.randn((m, n), dtype=dtype)
|
||||
|
||||
tensor_amax = torch.abs(x).max().to(torch.float32)
|
||||
global_scale = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / tensor_amax
|
||||
out_ref, scale_ref = ref_nvfp4_quant(x, global_scale)
|
||||
|
||||
out, out_scale = ops.scaled_fp4_quant(x, global_scale)
|
||||
scale_ans = recover_swizzled_scales(out_scale, m, n)
|
||||
out_ans = cast_from_fp4(out, m, n)
|
||||
torch.testing.assert_close(out_ans, out_ref)
|
||||
torch.testing.assert_close(scale_ans, scale_ref)
|
||||
268
tests/kernels/quantization/test_nvfp4_qutlass.py
Normal file
268
tests/kernels/quantization/test_nvfp4_qutlass.py
Normal file
@@ -0,0 +1,268 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
#
|
||||
# Copyright (C) 2025 Roberto L. Castro (Roberto.LopezCastro@ist.ac.at).
|
||||
# All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
from compressed_tensors.transform.utils.hadamard import deterministic_hadamard_matrix
|
||||
|
||||
from vllm import _custom_ops as ops # use existing nvfp4 gemm in vllm
|
||||
from vllm._custom_ops import fusedQuantizeNv
|
||||
from vllm.model_executor.layers.quantization.qutlass_utils import to_blocked
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
if not torch.cuda.is_available():
|
||||
pytest.skip("CUDA required for these tests.", allow_module_level=True)
|
||||
|
||||
if not (
|
||||
current_platform.has_device_capability(100)
|
||||
or current_platform.has_device_capability(120)
|
||||
):
|
||||
pytest.skip(
|
||||
reason="Tests require compute capability 10.0 (100) or 12.0 (120).",
|
||||
allow_module_level=True,
|
||||
)
|
||||
|
||||
|
||||
# ----- Helpers -----
|
||||
def get_hadamard_matrix(group_size: int, dtype: torch.dtype, device: torch.device):
|
||||
return (
|
||||
deterministic_hadamard_matrix(group_size, dtype=dtype, device=device)
|
||||
* group_size**-0.5
|
||||
)
|
||||
|
||||
|
||||
def _rtne_fp4(x: torch.Tensor):
|
||||
device = x.device
|
||||
grid = torch.tensor(
|
||||
[
|
||||
-6.0,
|
||||
-4.0,
|
||||
-3.0,
|
||||
-2.0,
|
||||
-1.5,
|
||||
-1.0,
|
||||
-0.5,
|
||||
-0.0,
|
||||
0.0,
|
||||
0.5,
|
||||
1.0,
|
||||
1.5,
|
||||
2.0,
|
||||
3.0,
|
||||
4.0,
|
||||
6.0,
|
||||
],
|
||||
dtype=x.dtype,
|
||||
device=x.device,
|
||||
)
|
||||
grid_int = torch.tensor(
|
||||
[-1, -2, -3, -4, -5, -6, -7, -8, 0, 1, 2, 3, 4, 5, 6, 7],
|
||||
dtype=torch.uint8,
|
||||
device=device,
|
||||
)
|
||||
inds = torch.bucketize(x, grid)
|
||||
lo, hi = (inds - 1).clamp(min=0, max=15), inds.clamp(min=0, max=15)
|
||||
g_lo, g_hi = grid[lo], grid[hi]
|
||||
pick_hi = (g_hi - x < x - g_lo) | (g_hi - x == x - g_lo) & (grid_int[hi] % 2 == 0)
|
||||
y = torch.where(pick_hi, g_hi, g_lo)
|
||||
y_int = torch.where(pick_hi, grid_int[hi], grid_int[lo])
|
||||
y_int_packed = (y_int[..., 1::2] & 0xF) << 4 | y_int[..., ::2] & 0xF
|
||||
return y, y_int_packed
|
||||
|
||||
|
||||
def _dq_fp4(x_e2m1: torch.Tensor, x_e4m3: torch.Tensor, alpha: float):
|
||||
device = x_e2m1.device
|
||||
|
||||
x_e2m1_i32 = x_e2m1.view(dtype=torch.uint8).to(dtype=torch.int32)
|
||||
x_e2m1_unpacked = torch.stack(
|
||||
[x_e2m1_i32 & 0xF, (x_e2m1_i32 >> 4) & 0xF], dim=-1
|
||||
).flatten(start_dim=-2)
|
||||
|
||||
grid_dq = torch.tensor(
|
||||
[
|
||||
0.0,
|
||||
0.5,
|
||||
1.0,
|
||||
1.5,
|
||||
2.0,
|
||||
3.0,
|
||||
4.0,
|
||||
6.0,
|
||||
-0.0,
|
||||
-0.5,
|
||||
-1.0,
|
||||
-1.5,
|
||||
-2.0,
|
||||
-3.0,
|
||||
-4.0,
|
||||
-6.0,
|
||||
],
|
||||
dtype=torch.float64,
|
||||
device=device,
|
||||
)
|
||||
x_fp4_dq = grid_dq[x_e2m1_unpacked]
|
||||
|
||||
scales_dq = x_e4m3.to(torch.float64)
|
||||
x_dq = (x_fp4_dq.unflatten(dim=-1, sizes=(-1, 16)) * scales_dq[..., None]).flatten(
|
||||
start_dim=-2
|
||||
) / alpha # * (4. / 3.)
|
||||
return x_dq, x_fp4_dq, scales_dq
|
||||
|
||||
|
||||
def _unpack_mask(clip_mask: torch.Tensor) -> torch.Tensor:
|
||||
clip_mask_unpacked_dq = torch.zeros(
|
||||
*clip_mask.shape[:-1],
|
||||
clip_mask.size(-1) * 8,
|
||||
dtype=torch.bool,
|
||||
device=clip_mask.device,
|
||||
)
|
||||
for i in range(8):
|
||||
clip_mask_unpacked_dq[..., i::8] = (clip_mask >> i) & 1
|
||||
return clip_mask_unpacked_dq
|
||||
|
||||
|
||||
def _forward_quantize_ref(x: torch.Tensor, h: torch.Tensor, rot_size: int):
|
||||
device = x.device
|
||||
|
||||
xh_ref64 = (
|
||||
x.unflatten(dim=-1, sizes=(-1, rot_size)).to(dtype=torch.float64)
|
||||
@ h.reshape(rot_size, rot_size).to(dtype=torch.float64)
|
||||
).flatten(start_dim=-2)
|
||||
|
||||
abs_max = xh_ref64.unflatten(dim=-1, sizes=(-1, 16)).abs().amax(dim=-1)
|
||||
scales_ref64_ = abs_max + 1e-8
|
||||
|
||||
xh_e4m3_ref = scales_ref64_.to(dtype=torch.float8_e4m3fn)
|
||||
scales_ref64 = xh_e4m3_ref.to(dtype=torch.float64)
|
||||
xh_scaled_ref64 = (
|
||||
xh_ref64.unflatten(dim=-1, sizes=(-1, 16)) / scales_ref64[..., None]
|
||||
).flatten(start_dim=-2)
|
||||
|
||||
xh_scaled_ref64 *= 6.0
|
||||
|
||||
clip_mask_unpacked_ref = xh_scaled_ref64.abs() < 6.0
|
||||
clip_mask_ref = torch.zeros(
|
||||
*x.shape[:-1], x.size(-1) // 8, dtype=torch.uint8, device=device
|
||||
)
|
||||
for i in range(8):
|
||||
clip_mask_ref |= clip_mask_unpacked_ref[..., i::8].to(dtype=torch.uint8) << i
|
||||
|
||||
xh_fp4_ref, xh_e2m1_ref = _rtne_fp4(xh_scaled_ref64)
|
||||
xh_dq, xh_fp4_dq, scales_dq = _dq_fp4(xh_e2m1_ref, xh_e4m3_ref, 6.0)
|
||||
clip_mask_unpacked_dq = _unpack_mask(clip_mask_ref)
|
||||
|
||||
assert xh_fp4_dq.equal(xh_fp4_ref)
|
||||
assert scales_dq.equal(scales_ref64)
|
||||
assert clip_mask_unpacked_dq.equal(clip_mask_unpacked_ref)
|
||||
|
||||
return (
|
||||
xh_dq,
|
||||
clip_mask_unpacked_ref,
|
||||
(xh_e2m1_ref, xh_e4m3_ref, clip_mask_ref),
|
||||
)
|
||||
|
||||
|
||||
DTYPE = torch.bfloat16
|
||||
DEVICE = torch.device("cuda:0")
|
||||
ROT_SIZES = [16, 32, 64, 128]
|
||||
GLOBAL_SCALES = [6.0]
|
||||
|
||||
LLAMA_MODELS = {
|
||||
"7B": [(4096, 3 * 4096), (4096, 4096), (4096, 2 * 10752), (10752, 4096)],
|
||||
"13B": [(5120, 3 * 5120), (5120, 5120), (5120, 2 * 13568), (13568, 5120)],
|
||||
"33B": [(6656, 3 * 6656), (6656, 6656), (6656, 2 * 17664), (17664, 6656)],
|
||||
"70B": [(8192, 3 * 8192), (8192, 8192), (8192, 2 * 21760), (21760, 8192)],
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _seed_each_test():
|
||||
current_platform.seed_everything(0)
|
||||
np.random.seed(0)
|
||||
torch.random.manual_seed(0)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("rot_size", ROT_SIZES)
|
||||
@pytest.mark.parametrize("global_scale_value", GLOBAL_SCALES)
|
||||
@torch.inference_mode()
|
||||
def test_fused_quantization(rot_size: int, global_scale_value: float):
|
||||
dtype, device = DTYPE, DEVICE
|
||||
h = get_hadamard_matrix(rot_size, dtype, device)
|
||||
x = torch.randn(2, 4096, 4096, dtype=dtype, device=device) * 25.0
|
||||
global_scale = torch.tensor([global_scale_value], device=device)
|
||||
|
||||
xh_dq_ref, _, _ = _forward_quantize_ref(x, h, rot_size)
|
||||
xh_e2m1, xh_e4m3 = fusedQuantizeNv(x, h, global_scale)
|
||||
xh_e4m3 = xh_e4m3.reshape(2, 4096, 4096 // 16)
|
||||
xh_dq, *_ = _dq_fp4(xh_e2m1, xh_e4m3, alpha=global_scale_value)
|
||||
|
||||
torch.testing.assert_close(xh_dq, xh_dq_ref, rtol=0.34, atol=100)
|
||||
assert (xh_dq != xh_dq_ref).float().mean() <= 1e-1
|
||||
|
||||
m, n, k = 504, 4096 * 2, 4096
|
||||
a = torch.randn(m, k, dtype=dtype, device=device) * 25.0
|
||||
b = torch.randn(n, k, dtype=dtype, device=device) * 25.0
|
||||
|
||||
a_e2m1, a_e4m3 = fusedQuantizeNv(a, h, global_scale)
|
||||
b_e2m1, b_e4m3 = fusedQuantizeNv(b, h, global_scale)
|
||||
|
||||
a_dq, *_ = _dq_fp4(a_e2m1, a_e4m3[:m, :k], alpha=1.0)
|
||||
b_dq, *_ = _dq_fp4(b_e2m1, b_e4m3[:n, :k], alpha=1.0)
|
||||
out_ref = a_dq @ b_dq.transpose(-2, -1)
|
||||
|
||||
a_scale_block = to_blocked(a_e4m3, backend="triton").view(-1, k // 16)
|
||||
b_scale_block = to_blocked(b_e4m3, backend="triton").view(-1, k // 16)
|
||||
alpha = torch.tensor([1.0], device=device)
|
||||
out = ops.cutlass_scaled_fp4_mm(
|
||||
a_e2m1, b_e2m1, a_scale_block, b_scale_block, alpha, torch.bfloat16
|
||||
)
|
||||
assert out.equal(out_ref.to(dtype=out.dtype))
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", list(LLAMA_MODELS.keys()))
|
||||
@pytest.mark.parametrize("layer_idx", [0, 1, 2, 3])
|
||||
@pytest.mark.parametrize("batch", [1, 16])
|
||||
@pytest.mark.parametrize("rot_size", ROT_SIZES)
|
||||
@torch.inference_mode()
|
||||
def test_llama_shapes(model: str, layer_idx: int, batch: int, rot_size: int):
|
||||
dtype, device = DTYPE, DEVICE
|
||||
m = batch
|
||||
k, n = LLAMA_MODELS[model][layer_idx]
|
||||
|
||||
h = get_hadamard_matrix(rot_size, dtype, device)
|
||||
|
||||
a = torch.randn(m, k, dtype=dtype, device=device) * 25.0
|
||||
b = torch.randn(n, k, dtype=dtype, device=device) * 25.0
|
||||
|
||||
global_scale = torch.tensor([1.0], device=device)
|
||||
|
||||
a_e2m1, a_e4m3 = fusedQuantizeNv(a, h, global_scale)
|
||||
b_e2m1, b_e4m3 = fusedQuantizeNv(b, h, global_scale)
|
||||
|
||||
a_dq, *_ = _dq_fp4(a_e2m1, a_e4m3[:m, :k], alpha=1.0)
|
||||
b_dq, *_ = _dq_fp4(b_e2m1, b_e4m3[:n, :k], alpha=1.0)
|
||||
out_ref = a_dq @ b_dq.transpose(-2, -1)
|
||||
|
||||
a_scale_block = to_blocked(a_e4m3, backend="triton").view(-1, k // 16)
|
||||
b_scale_block = to_blocked(b_e4m3, backend="triton").view(-1, k // 16)
|
||||
alpha = torch.tensor([1.0], device=device)
|
||||
out = ops.cutlass_scaled_fp4_mm(
|
||||
a_e2m1, b_e2m1, a_scale_block, b_scale_block, alpha, torch.bfloat16
|
||||
)
|
||||
assert out.equal(out_ref.to(dtype=out.dtype))
|
||||
99
tests/kernels/quantization/test_nvfp4_scaled_mm.py
Normal file
99
tests/kernels/quantization/test_nvfp4_scaled_mm.py
Normal file
@@ -0,0 +1,99 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import pytest
|
||||
import torch
|
||||
from nvfp4_utils import FLOAT4_E2M1_MAX, FLOAT8_E4M3_MAX, dequantize_nvfp4_to_dtype
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
if not current_platform.has_device_capability(100):
|
||||
pytest.skip(
|
||||
reason="Nvfp4 Requires compute capability of 10 or above.",
|
||||
allow_module_level=True,
|
||||
)
|
||||
|
||||
DTYPES = [torch.float16, torch.bfloat16]
|
||||
# m, n, k
|
||||
SHAPES = [(128, 128, 64), (128, 128, 128), (256, 128, 64), (128, 256, 128)]
|
||||
PAD_SHAPES = [(150, 128, 64), (128, 128, 96)]
|
||||
SHAPES.extend(PAD_SHAPES)
|
||||
|
||||
SEEDS = [42]
|
||||
CUDA_DEVICES = ["cuda:0"]
|
||||
|
||||
|
||||
def get_ref_results(
|
||||
a_fp4,
|
||||
b_fp4,
|
||||
a_sf,
|
||||
b_sf,
|
||||
a_global_scale,
|
||||
b_global_scale,
|
||||
m,
|
||||
n,
|
||||
dtype,
|
||||
block_size,
|
||||
device,
|
||||
):
|
||||
_, m_k = a_fp4.shape
|
||||
_, n_k = b_fp4.shape
|
||||
assert m_k == n_k
|
||||
a_in_dtype = dequantize_nvfp4_to_dtype(
|
||||
a_fp4, a_sf, a_global_scale, dtype=dtype, device=device, block_size=block_size
|
||||
)
|
||||
b_in_dtype = dequantize_nvfp4_to_dtype(
|
||||
b_fp4, b_sf, b_global_scale, dtype=dtype, device=device, block_size=block_size
|
||||
)
|
||||
return torch.matmul(a_in_dtype, b_in_dtype.t())
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("shape", SHAPES)
|
||||
@pytest.mark.parametrize("seed", SEEDS)
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@torch.inference_mode()
|
||||
def test_nvfp4_gemm(
|
||||
dtype: torch.dtype,
|
||||
shape: tuple[int, int, int],
|
||||
seed: int,
|
||||
device: str,
|
||||
) -> None:
|
||||
current_platform.seed_everything(seed)
|
||||
m, n, packed_k = shape
|
||||
k = packed_k * 2
|
||||
block_size = 16
|
||||
a_dtype = torch.randn((m, k), dtype=dtype, device=device)
|
||||
b_dtype = torch.randn((n, k), dtype=dtype, device=device)
|
||||
|
||||
a_global_scale = (
|
||||
(FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.amax(a_dtype.flatten(), dim=-1)
|
||||
).to(torch.float32)
|
||||
b_global_scale = (
|
||||
(FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.amax(b_dtype.flatten(), dim=-1)
|
||||
).to(torch.float32)
|
||||
alpha = 1.0 / (a_global_scale * b_global_scale)
|
||||
# ops.scaled_fp4_quant returns swizzled scales, while weights
|
||||
# from checkpoints are in linear scales.
|
||||
a_fp4, a_scale_interleaved = ops.scaled_fp4_quant(a_dtype, a_global_scale)
|
||||
b_fp4, b_scale_interleaved = ops.scaled_fp4_quant(b_dtype, b_global_scale)
|
||||
|
||||
# get_ref_results unswizzles the scales internally.
|
||||
expected_out = get_ref_results(
|
||||
a_fp4,
|
||||
b_fp4,
|
||||
a_scale_interleaved,
|
||||
b_scale_interleaved,
|
||||
a_global_scale,
|
||||
b_global_scale,
|
||||
m,
|
||||
n,
|
||||
dtype,
|
||||
block_size,
|
||||
device,
|
||||
)
|
||||
out = ops.cutlass_scaled_fp4_mm(
|
||||
a_fp4, b_fp4, a_scale_interleaved, b_scale_interleaved, alpha, dtype
|
||||
)
|
||||
|
||||
torch.testing.assert_close(out, expected_out.to(dtype=dtype), atol=1e-1, rtol=1e-1)
|
||||
72
tests/kernels/quantization/test_per_token_group_quant.py
Normal file
72
tests/kernels/quantization/test_per_token_group_quant.py
Normal file
@@ -0,0 +1,72 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.layers.quantization.utils import fp8_utils, int8_utils
|
||||
|
||||
|
||||
@pytest.mark.parametrize("shape", [(32, 128), (64, 256), (16, 512)])
|
||||
@pytest.mark.parametrize("column_major", [False, True])
|
||||
@pytest.mark.parametrize("scale_ue8m0", [False, True])
|
||||
@pytest.mark.parametrize("group_size", [64, 128])
|
||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
|
||||
def test_per_token_group_quant_fp8(
|
||||
shape, column_major: bool, scale_ue8m0: bool, group_size: int
|
||||
):
|
||||
device = "cuda"
|
||||
|
||||
torch.manual_seed(42)
|
||||
num_tokens, hidden_dim = shape
|
||||
|
||||
x = torch.randn((num_tokens, hidden_dim), device=device, dtype=torch.bfloat16) * 8
|
||||
|
||||
# cuda path
|
||||
out_q, scale = fp8_utils.per_token_group_quant_fp8(
|
||||
x,
|
||||
group_size,
|
||||
column_major_scales=column_major,
|
||||
use_ue8m0=scale_ue8m0,
|
||||
)
|
||||
|
||||
# triton ref
|
||||
with patch("vllm.platforms.current_platform.is_cuda", return_value=False):
|
||||
ref_q, ref_s = fp8_utils.per_token_group_quant_fp8(
|
||||
x,
|
||||
group_size,
|
||||
column_major_scales=column_major,
|
||||
use_ue8m0=scale_ue8m0,
|
||||
)
|
||||
|
||||
assert torch.allclose(out_q.float(), ref_q.float(), atol=0.15, rtol=0.15)
|
||||
assert torch.allclose(scale, ref_s, atol=0.01, rtol=0.01)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("shape", [(32, 128), (64, 256), (16, 512)])
|
||||
@pytest.mark.parametrize("group_size", [64, 128])
|
||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
|
||||
def test_per_token_group_quant_int8(shape, group_size: int):
|
||||
device = "cuda"
|
||||
|
||||
torch.manual_seed(42)
|
||||
num_tokens, hidden_dim = shape
|
||||
|
||||
x = torch.randn((num_tokens, hidden_dim), device=device, dtype=torch.bfloat16) * 8
|
||||
|
||||
# cuda path
|
||||
out_q, scale = int8_utils.per_token_group_quant_int8(
|
||||
x,
|
||||
group_size,
|
||||
)
|
||||
|
||||
# triton ref
|
||||
with patch("vllm.platforms.current_platform.is_cuda", return_value=False):
|
||||
ref_q, ref_s = int8_utils.per_token_group_quant_int8(
|
||||
x,
|
||||
group_size,
|
||||
)
|
||||
|
||||
assert torch.allclose(out_q.float(), ref_q.float(), atol=0.15, rtol=0.15)
|
||||
assert torch.allclose(scale, ref_s, atol=0.01, rtol=0.01)
|
||||
200
tests/kernels/quantization/test_rocm_skinny_gemms.py
Normal file
200
tests/kernels/quantization/test_rocm_skinny_gemms.py
Normal file
@@ -0,0 +1,200 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import math
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
import vllm._custom_ops as ops
|
||||
from tests.kernels.quant_utils import ref_dynamic_per_tensor_fp8_quant
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.platform_utils import get_cu_count
|
||||
|
||||
DTYPES = [torch.bfloat16, torch.float16]
|
||||
# Specific (N, K, M) combinations for targeted testing
|
||||
NKM_FACTORS_LLMM1 = [
|
||||
# Small, medium, large cases
|
||||
(1, 8, 16),
|
||||
(1, 32, 64),
|
||||
(1, 128, 256),
|
||||
(1, 512, 1024),
|
||||
(1, 2048, 4096),
|
||||
# Edge cases with specific K sizes
|
||||
(1, 6144, 1024),
|
||||
(1, 8192, 2048),
|
||||
# Very large case
|
||||
(1, 4096, 8192),
|
||||
]
|
||||
|
||||
NKM_FACTORS_WVSPLITK = [
|
||||
# Different batch sizes with key dimensions
|
||||
(1, 16, 16),
|
||||
(1, 64, 64),
|
||||
(2, 256, 256),
|
||||
(3, 1024, 1024),
|
||||
(4, 4096, 4096),
|
||||
# Extended K values
|
||||
(1, 9216, 512),
|
||||
(2, 10240, 1024),
|
||||
(4, 16384, 8192),
|
||||
# Minimum M constraint validation (m >= 8)
|
||||
(1, 64, 8),
|
||||
(2, 128, 8),
|
||||
(4, 256, 8),
|
||||
]
|
||||
|
||||
NKM_FACTORS_WVSPLITK_FP8 = [
|
||||
# FP8-specific cases with K % 16 == 0
|
||||
(1, 16, 16),
|
||||
(1, 64, 64),
|
||||
(2, 512, 512),
|
||||
(3, 2048, 2048),
|
||||
(4, 4096, 4096),
|
||||
(4, 16400, 2048),
|
||||
# Extended FP8 dimensions not covered by WVSPLITK
|
||||
(1, 14336, 1024),
|
||||
(2, 24576, 2048),
|
||||
(4, 32768, 28672),
|
||||
]
|
||||
|
||||
SEEDS = [0]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("n,k,m", NKM_FACTORS_LLMM1)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("rows_per_block", [2, 4, 8, 16])
|
||||
@pytest.mark.parametrize("seed", SEEDS)
|
||||
@pytest.mark.skipif(not current_platform.is_rocm(), reason="only test for rocm")
|
||||
@torch.inference_mode()
|
||||
def test_rocm_llmm1_kernel(n, k, m, dtype, rows_per_block, seed):
|
||||
torch.manual_seed(seed)
|
||||
# TODO: Zero-centering the inputs causes errors for LLMM1!
|
||||
# Without that the numbers quickly saturate, and may
|
||||
# be giving false matches.
|
||||
A = torch.rand(n, k, dtype=dtype, device="cuda")
|
||||
B = torch.rand(m, k, dtype=dtype, device="cuda")
|
||||
|
||||
ref_out = torch.matmul(A, B.t())
|
||||
out = ops.LLMM1(B, A, rows_per_block)
|
||||
|
||||
assert torch.allclose(out, ref_out, rtol=0.01)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("n,k,m", NKM_FACTORS_WVSPLITK)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("seed", SEEDS)
|
||||
@pytest.mark.skipif(not current_platform.is_rocm(), reason="only test for rocm")
|
||||
def test_rocm_wvsplitk_kernel(n, k, m, dtype, seed):
|
||||
torch.manual_seed(seed)
|
||||
cu_count = get_cu_count()
|
||||
|
||||
A = torch.rand(n, k, dtype=dtype, device="cuda") - 0.5
|
||||
B = torch.rand(m, k, dtype=dtype, device="cuda") - 0.5
|
||||
|
||||
ref_out = torch.nn.functional.linear(A, B)
|
||||
out = ops.wvSplitK(B, A.view(-1, A.size(-1)), cu_count)
|
||||
|
||||
assert torch.allclose(out, ref_out, rtol=0.01)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("n,k,m", NKM_FACTORS_WVSPLITK)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("seed", SEEDS)
|
||||
@pytest.mark.skipif(not current_platform.is_rocm(), reason="only test for rocm")
|
||||
def test_rocm_wvsplitk_bias1D_kernel(n, k, m, dtype, seed):
|
||||
torch.manual_seed(seed)
|
||||
cu_count = get_cu_count()
|
||||
|
||||
xavier = math.sqrt(2 / k) # normalize to avoid large output-bias deltas
|
||||
A = (torch.rand(n, k, dtype=dtype, device="cuda") - 0.5) * xavier
|
||||
B = (torch.rand(m, k, dtype=dtype, device="cuda") - 0.5) * xavier
|
||||
BIAS = torch.rand(m, dtype=dtype, device="cuda") - 0.5
|
||||
|
||||
ref_out = torch.nn.functional.linear(A, B, BIAS)
|
||||
out = ops.wvSplitK(B, A.view(-1, A.size(-1)), cu_count, BIAS)
|
||||
|
||||
assert torch.allclose(out, ref_out, rtol=0.01)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("n,k,m", NKM_FACTORS_WVSPLITK)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("seed", SEEDS)
|
||||
@pytest.mark.skipif(not current_platform.is_rocm(), reason="only test for rocm")
|
||||
def test_rocm_wvsplitk_bias2D_kernel(n, k, m, dtype, seed):
|
||||
torch.manual_seed(seed)
|
||||
cu_count = get_cu_count()
|
||||
|
||||
xavier = math.sqrt(2 / k) # normalize to avoid large output-bias deltas
|
||||
A = (torch.rand(n, k, dtype=dtype, device="cuda") - 0.5) * xavier
|
||||
B = (torch.rand(m, k, dtype=dtype, device="cuda") - 0.5) * xavier
|
||||
BIAS = torch.rand(n, m, dtype=dtype, device="cuda") - 0.5
|
||||
|
||||
ref_out = torch.nn.functional.linear(A, B, BIAS)
|
||||
out = ops.wvSplitK(B, A.view(-1, A.size(-1)), cu_count, BIAS)
|
||||
|
||||
assert torch.allclose(out, ref_out, rtol=0.01)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("n,k,m", NKM_FACTORS_WVSPLITK_FP8)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("seed", SEEDS)
|
||||
@pytest.mark.skipif(
|
||||
not (current_platform.is_rocm() and current_platform.supports_fp8()),
|
||||
reason="only test for rocm fp8",
|
||||
)
|
||||
def test_rocm_wvsplitk_fp8_kernel(n, k, m, dtype, seed):
|
||||
torch.manual_seed(seed)
|
||||
|
||||
A = torch.rand(n, k, device="cuda") - 0.5
|
||||
B = torch.rand(m, k, device="cuda") - 0.5
|
||||
|
||||
A, scale_a = ref_dynamic_per_tensor_fp8_quant(A)
|
||||
B, scale_b = ref_dynamic_per_tensor_fp8_quant(B)
|
||||
|
||||
ref_out = torch._scaled_mm(
|
||||
A, B.t(), out_dtype=dtype, scale_a=scale_a, scale_b=scale_b
|
||||
)
|
||||
out = ops.wvSplitKQ(
|
||||
B,
|
||||
A,
|
||||
dtype,
|
||||
scale_a,
|
||||
scale_b,
|
||||
get_cu_count(),
|
||||
)
|
||||
|
||||
assert torch.allclose(out, ref_out, rtol=0.01)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("n,k,m", NKM_FACTORS_WVSPLITK_FP8)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("seed", SEEDS)
|
||||
@pytest.mark.skipif(
|
||||
not (current_platform.is_rocm() and current_platform.supports_fp8()),
|
||||
reason="only test for rocm fp8",
|
||||
)
|
||||
def test_rocm_wvsplitk_fp8_bias1D_kernel(n, k, m, dtype, seed):
|
||||
torch.manual_seed(seed)
|
||||
|
||||
xavier = math.sqrt(2 / k) # normalize to avoid large output-bias deltas
|
||||
A = (torch.rand(n, k, device="cuda") - 0.5) * xavier
|
||||
B = (torch.rand(m, k, device="cuda") - 0.5) * xavier
|
||||
BIAS = torch.rand(m, dtype=dtype, device="cuda") - 0.5
|
||||
|
||||
A, scale_a = ref_dynamic_per_tensor_fp8_quant(A)
|
||||
B, scale_b = ref_dynamic_per_tensor_fp8_quant(B)
|
||||
|
||||
ref_out = torch._scaled_mm(
|
||||
A, B.t(), out_dtype=dtype, scale_a=scale_a, scale_b=scale_b, bias=BIAS
|
||||
)
|
||||
out = ops.wvSplitKQ(
|
||||
B,
|
||||
A,
|
||||
dtype,
|
||||
scale_a,
|
||||
scale_b,
|
||||
get_cu_count(),
|
||||
BIAS,
|
||||
)
|
||||
|
||||
assert torch.allclose(out, ref_out, rtol=0.01)
|
||||
@@ -0,0 +1,91 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Tests for ScaledMM kernel selection logic (CPU-only)
|
||||
|
||||
Run `pytest tests/kernels/quantization/test_scaled_mm_kernel_selection.py`.
|
||||
"""
|
||||
|
||||
import inspect
|
||||
from abc import ABC
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
|
||||
ScaledMMLinearLayerConfig,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm.aiter import (
|
||||
AiterScaledMMLinearKernel,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm.cpu import (
|
||||
CPUScaledMMLinearKernel,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501
|
||||
ScaledMMLinearKernel,
|
||||
)
|
||||
|
||||
pytestmark = pytest.mark.cpu_test
|
||||
|
||||
|
||||
def test_is_supported_is_abstract():
|
||||
"""Test that is_supported() is properly defined as abstract."""
|
||||
assert issubclass(ScaledMMLinearKernel, ABC)
|
||||
assert hasattr(ScaledMMLinearKernel, "is_supported")
|
||||
|
||||
|
||||
def test_cpu_kernel_implements_is_supported():
|
||||
"""Test that CPUScaledMMLinearKernel implements is_supported() method."""
|
||||
assert hasattr(CPUScaledMMLinearKernel, "is_supported"), (
|
||||
"CPUScaledMMLinearKernel missing is_supported() method"
|
||||
)
|
||||
# Verify it's a classmethod by checking if it can be called with the class
|
||||
# and by checking the method type
|
||||
assert inspect.ismethod(CPUScaledMMLinearKernel.is_supported) or inspect.isfunction(
|
||||
CPUScaledMMLinearKernel.is_supported
|
||||
), "CPUScaledMMLinearKernel.is_supported() should be a classmethod"
|
||||
# Verify it can be called as a classmethod
|
||||
result, reason = CPUScaledMMLinearKernel.is_supported()
|
||||
assert isinstance(result, bool), "is_supported() should return a bool"
|
||||
assert reason is None or isinstance(reason, str), "reason should be str or None"
|
||||
|
||||
|
||||
def test_aiter_kernel_implements_is_supported():
|
||||
"""Test that AiterScaledMMLinearKernel implements is_supported() method."""
|
||||
assert hasattr(AiterScaledMMLinearKernel, "is_supported"), (
|
||||
"AiterScaledMMLinearKernel missing is_supported() method"
|
||||
)
|
||||
# Verify it's a classmethod by checking if it can be called with the class
|
||||
# and by checking the method type
|
||||
assert inspect.ismethod(
|
||||
AiterScaledMMLinearKernel.is_supported
|
||||
) or inspect.isfunction(AiterScaledMMLinearKernel.is_supported), (
|
||||
"AiterScaledMMLinearKernel.is_supported() should be a classmethod"
|
||||
)
|
||||
# Verify it can be called as a classmethod
|
||||
# (will return False on CPU, which is expected)
|
||||
result, reason = AiterScaledMMLinearKernel.is_supported()
|
||||
assert isinstance(result, bool), "is_supported() should return a bool"
|
||||
assert reason is None or isinstance(reason, str), "reason should be str or None"
|
||||
# On CPU, it should return False with a reason about requiring ROCm
|
||||
# This validates the method works correctly even on non-ROCm platforms
|
||||
|
||||
|
||||
def test_cpu_kernel_accepts_all_configs():
|
||||
"""Test that CPUScaledMMLinearKernel accepts all config combinations."""
|
||||
configs = [
|
||||
ScaledMMLinearLayerConfig(
|
||||
is_channelwise=False,
|
||||
is_static_input_scheme=True,
|
||||
input_symmetric=True,
|
||||
),
|
||||
ScaledMMLinearLayerConfig(
|
||||
is_channelwise=True,
|
||||
is_static_input_scheme=False,
|
||||
input_symmetric=False,
|
||||
),
|
||||
]
|
||||
|
||||
for config in configs:
|
||||
can_impl, reason = CPUScaledMMLinearKernel.can_implement(config)
|
||||
assert can_impl, (
|
||||
f"CPUScaledMMLinearKernel should accept config {config}: {reason}"
|
||||
)
|
||||
76
tests/kernels/quantization/test_silu_mul_nvfp4_quant.py
Normal file
76
tests/kernels/quantization/test_silu_mul_nvfp4_quant.py
Normal file
@@ -0,0 +1,76 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from tests.kernels.quantization.nvfp4_utils import (
|
||||
FLOAT4_E2M1_MAX,
|
||||
FLOAT8_E4M3_MAX,
|
||||
dequantize_nvfp4_to_dtype,
|
||||
)
|
||||
from vllm._custom_ops import scaled_fp4_quant
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
if not current_platform.has_device_capability(100):
|
||||
pytest.skip(
|
||||
reason="Nvfp4 Requires compute capability of 10 or above.",
|
||||
allow_module_level=True,
|
||||
)
|
||||
|
||||
FP4_DTYPE = torch.uint8
|
||||
FP8_DTYPE = current_platform.fp8_dtype()
|
||||
|
||||
DTYPES = [torch.float16, torch.bfloat16]
|
||||
SHAPES = [(128, 256), (128, 128), (256, 256), (256, 128)]
|
||||
BLOCK_SIZE = 16
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("shape", SHAPES)
|
||||
@torch.inference_mode()
|
||||
def test_silu_mul_nvfp4_quant(
|
||||
dtype: torch.dtype,
|
||||
shape: tuple[int, int],
|
||||
) -> None:
|
||||
current_platform.seed_everything(42)
|
||||
device = "cuda:0"
|
||||
torch.set_default_device(device)
|
||||
|
||||
x = torch.randn(shape, dtype=dtype)
|
||||
|
||||
# ref op
|
||||
ref_output = SiluAndMul().forward_native(x)
|
||||
ref_global_scale = (FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.abs(
|
||||
ref_output
|
||||
).max().to(torch.float32)
|
||||
ref_output_quant, ref_block_scale = scaled_fp4_quant(ref_output, ref_global_scale)
|
||||
|
||||
# fused op
|
||||
fused_output_quant = torch.empty_like(ref_output_quant)
|
||||
fused_block_scale = torch.empty_like(ref_block_scale)
|
||||
torch.ops._C.silu_and_mul_nvfp4_quant(
|
||||
fused_output_quant, fused_block_scale, x, ref_global_scale
|
||||
)
|
||||
|
||||
# check dtype
|
||||
assert ref_output_quant.dtype == FP4_DTYPE
|
||||
assert fused_output_quant.dtype == FP4_DTYPE
|
||||
assert ref_output_quant.shape == fused_output_quant.shape
|
||||
|
||||
assert ref_block_scale.dtype == FP8_DTYPE
|
||||
assert fused_block_scale.dtype == FP8_DTYPE
|
||||
assert ref_block_scale.shape == fused_block_scale.shape
|
||||
|
||||
# check dequantized output
|
||||
ref_output_dequant = dequantize_nvfp4_to_dtype(
|
||||
ref_output_quant, ref_block_scale, ref_global_scale, dtype, device
|
||||
)
|
||||
fused_output_dequant = dequantize_nvfp4_to_dtype(
|
||||
fused_output_quant, fused_block_scale, ref_global_scale, dtype, device
|
||||
)
|
||||
|
||||
atol, rtol = 3e-1, 3e-1
|
||||
torch.testing.assert_close(
|
||||
ref_output_dequant, fused_output_dequant, atol=atol, rtol=rtol
|
||||
)
|
||||
124
tests/kernels/quantization/test_triton_scaled_mm.py
Normal file
124
tests/kernels/quantization/test_triton_scaled_mm.py
Normal file
@@ -0,0 +1,124 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Tests for the triton_scaled_mm kernel
|
||||
|
||||
Run `pytest tests/kernels/quantization/test_triton_scaled_mm.py`.
|
||||
"""
|
||||
|
||||
import importlib
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
device = "cuda"
|
||||
|
||||
triton_scaled_mm_module = importlib.import_module(
|
||||
"vllm.model_executor.layers.quantization.compressed_tensors.triton_scaled_mm"
|
||||
)
|
||||
triton_scaled_mm = triton_scaled_mm_module.triton_scaled_mm
|
||||
|
||||
|
||||
def torch_scaled_mm(
|
||||
a: torch.Tensor,
|
||||
b: torch.Tensor,
|
||||
scale_a: torch.Tensor,
|
||||
scale_b: torch.Tensor,
|
||||
out_dtype: type[torch.dtype],
|
||||
bias: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
out = torch.mm(a.to(torch.float32), b.to(torch.float32))
|
||||
out = scale_a * out
|
||||
out = scale_b.T * out
|
||||
out = out.to(out_dtype)
|
||||
if bias is not None:
|
||||
out = out + bias
|
||||
|
||||
return out
|
||||
|
||||
|
||||
def get_8bit_types():
|
||||
types = [torch.int8]
|
||||
if current_platform.supports_fp8():
|
||||
types.append(current_platform.fp8_dtype())
|
||||
return types
|
||||
|
||||
|
||||
# This test is to check regressions for int8 support on ROCm.
|
||||
@pytest.mark.parametrize(
|
||||
"model_path",
|
||||
[
|
||||
"neuralmagic/Llama-3.2-1B-quantized.w8a8",
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize("max_tokens", [32])
|
||||
@pytest.mark.parametrize("num_logprobs", [10])
|
||||
@pytest.mark.skipif(not current_platform.is_rocm(), reason="Should only run on ROCm")
|
||||
def test_rocm_compressed_tensors_w8a8(
|
||||
vllm_runner, example_prompts, model_path, max_tokens, num_logprobs
|
||||
):
|
||||
dtype = "bfloat16"
|
||||
|
||||
with vllm_runner(model_path, dtype=dtype) as vllm_model:
|
||||
vllm_model.generate_greedy_logprobs(example_prompts, max_tokens, num_logprobs)
|
||||
|
||||
|
||||
MNK_FACTORS = [
|
||||
(1, 256, 128),
|
||||
(33, 256, 496),
|
||||
(64, 971, 1024),
|
||||
(64, 20486, 128),
|
||||
(512, 256, 496),
|
||||
(512, 20486, 1024),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("M,N,K", MNK_FACTORS)
|
||||
@pytest.mark.parametrize("out_dtype", [torch.bfloat16])
|
||||
@pytest.mark.parametrize("in_dtype", get_8bit_types())
|
||||
@pytest.mark.parametrize("use_scalar_scale_a", [True, False])
|
||||
@pytest.mark.parametrize("use_scalar_scale_b", [True, False])
|
||||
@pytest.mark.parametrize("use_bias", [True, False])
|
||||
def test_scaled_mm(
|
||||
M, N, K, in_dtype, out_dtype, use_scalar_scale_a, use_scalar_scale_b, use_bias
|
||||
):
|
||||
is_floating_point_type = lambda t: torch.tensor([1, 1], dtype=t).is_floating_point()
|
||||
|
||||
current_platform.seed_everything(0)
|
||||
|
||||
# NOTE: There are cases, where if the matrix is large enough, an output
|
||||
# like 65504.4 can be produced, and can easily turn into inf when
|
||||
# multiplied when using float16/bfloat16. This means one function, e.g.,
|
||||
# testing function, and another function, e.g. golden function, can
|
||||
# produce a non-inf value while the other produces an inf value, and
|
||||
# will cause assert_close/allclose to fail, even though if overflow
|
||||
# wouldn't have occurred, the values would have been "close."
|
||||
#
|
||||
# So, the values here are kept small enough to avoid this situation.
|
||||
if is_floating_point_type(in_dtype):
|
||||
a = (0.25 * torch.rand((M, K), dtype=torch.float32, device=device)).to(in_dtype)
|
||||
b = (0.25 * torch.rand((K, N), dtype=torch.float32, device=device)).to(in_dtype)
|
||||
else:
|
||||
a = torch.randint(-32, 32, (M, K), dtype=in_dtype, device=device)
|
||||
b = torch.randint(-32, 32, (K, N), dtype=in_dtype, device=device)
|
||||
|
||||
if use_scalar_scale_a:
|
||||
scale_a = torch.rand((1, 1), device=device)
|
||||
else:
|
||||
scale_a = 0.25 * torch.rand((M, 1), device=device)
|
||||
|
||||
if use_scalar_scale_b:
|
||||
scale_b = torch.rand((1, 1), device=device)
|
||||
else:
|
||||
scale_b = 0.25 * torch.rand((N, 1), device=device)
|
||||
|
||||
bias = None
|
||||
if use_bias:
|
||||
bias = torch.rand((N,), device=device, dtype=out_dtype)
|
||||
|
||||
c_check = triton_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
|
||||
|
||||
c_actual = torch_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
|
||||
|
||||
torch.testing.assert_close(c_check, c_actual, rtol=1e-1, atol=1e-1)
|
||||
Reference in New Issue
Block a user