Sync from v0.13
This commit is contained in:
342
tests/kernels/quantization/test_cutlass_w4a8_moe.py
Normal file
342
tests/kernels/quantization/test_cutlass_w4a8_moe.py
Normal file
@@ -0,0 +1,342 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
Tests for the CUTLASS-based W4A8 grouped GEMM kernel and the full MoE layer.
|
||||
"""
|
||||
|
||||
import random
|
||||
from dataclasses import dataclass
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
pack_rows,
|
||||
quantize_weights,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.scalar_type import ScalarType, scalar_types
|
||||
|
||||
# True only on CUDA devices with compute capability major version >= 9
# (Hopper and newer); the tests below are skipped everywhere else.
IS_SUPPORTED_BY_GPU = (
    current_platform.is_cuda() and current_platform.get_device_capability()[0] >= 9
)
|
||||
|
||||
|
||||
def to_fp8(tensor: torch.Tensor) -> torch.Tensor:
|
||||
finfo = torch.finfo(torch.float8_e4m3fn)
|
||||
return tensor.clamp(min=finfo.min, max=finfo.max).to(dtype=torch.float8_e4m3fn)
|
||||
|
||||
|
||||
def cutlass_quantize(
    atype: torch.dtype,
    w: torch.Tensor,
    wtype: ScalarType,
    stype: torch.dtype | None,
    group_size: int | None,
    zero_points: bool = False,
):
    """Quantize *w* to W4 and build a dequantized reference copy in *atype*.

    Encoding/reordering of the packed weights and packing of the scales are
    deferred until all experts are combined (see ``cutlass_preprocess``).
    NOTE(review): ``stype`` is currently unused by this helper — confirm
    whether it is reserved for future scale-dtype support.
    """
    assert wtype.is_integer(), "TODO: support floating point weights"

    _, q, scales, zp = quantize_weights(
        w, wtype, group_size=group_size, zero_points=zero_points
    )

    # Scales are later stored as fp8, so the reference dequantization must
    # round-trip them through `atype` first to match the kernel's precision.
    expanded_scales = (
        scales.to(atype).to(torch.float32).repeat_interleave(group_size, dim=0)
    )
    ref = (q.to(torch.float32) * expanded_scales).to(atype)

    # Mask to the low nibble so int4 values do not sign-extend when packed.
    packed = pack_rows(q & 0x0F, wtype.size_bits, *q.shape)
    # Transpose so the packed weights are row-major (N, K).
    packed = packed.t().contiguous()

    return ref, packed, scales.to(atype), zp
|
||||
|
||||
|
||||
def cutlass_preprocess(
    w_q_experts: list[torch.Tensor], w_s_experts: list[torch.Tensor]
):
    """Reorder/encode expert weights and pack scales for the grouped GEMM.

    Returns:
        w_q_packed: Packed/encoded int4 weights for all experts.
        w_s_packed: Packed fp8 scales for all experts.
        packed_layout: Layout/stride metadata for grouped GEMM.
    """
    packed_scales = ops.cutlass_pack_scale_fp8(torch.stack(w_s_experts))
    # The encode/reorder op expects a 3-D tensor: (num_experts, ..., ...).
    stacked_weights = torch.stack(w_q_experts)
    packed_weights, layout = ops.cutlass_encode_and_reorder_int4b_grouped(
        stacked_weights
    )
    return packed_weights, packed_scales, layout
|
||||
|
||||
|
||||
# Quantization group size along K used by every test below.
GROUP_SIZE = 128
# Test problem shapes as (num_experts, N, K).
TEST_SHAPES = [
    (8, 512, 2048),
    (8, 2048, 2048),
    (64, 512, 1024),
    (64, 2048, 2048),
    (4, 2048, 768),
    (8, 768, 2048),
    (64, 1536, 2048),
    (128, 8192, 4096),  # test overflow int32
]
ALIGNMENT = 16  # torch._scaled_mm alignment for M, needed for reference check
|
||||
|
||||
|
||||
@dataclass
class MoETestSetup:
    """All tensors and metadata required for one cutlass_w4a8_moe_mm call."""

    num_experts: int
    K: int  # reduction dimension; multiple of GROUP_SIZE
    N: int  # output dimension per expert
    Ms: list[int]  # tokens routed to each expert (entries may be 0)
    M_full: int  # total token count, sum(Ms)
    a: torch.Tensor  # fp8 activations, shape (M_full, K)
    a_ref: torch.Tensor  # float32 copy of `a` for the reference path
    a_strides: torch.Tensor  # int64 per-expert A row strides, all equal to K
    out: torch.Tensor  # bf16 kernel output buffer, shape (M_full, N)
    c_strides: torch.Tensor  # int64 per-expert output row strides, all equal to N
    per_tok_scales: torch.Tensor  # fp32 per-token scales, shape (M_full, 1)
    per_chan_scales: torch.Tensor  # fp32 per-channel scales, (num_experts, N, 1)
    w_refs: list[torch.Tensor]  # dequantized reference weights, one per expert
    w_q_packed: torch.Tensor  # encoded/reordered int4 weights for all experts
    w_s_packed: torch.Tensor  # packed fp8 group scales for all experts
    problem_sizes: torch.Tensor  # int32 (num_experts, 3), rows are [N, M, K]
    expert_offsets: torch.Tensor  # int64 start row per expert (exclusive cumsum of Ms)
    b_strides: torch.Tensor  # layout metadata from the encode/reorder op
    group_scale_strides: torch.Tensor  # int64 (num_experts, 2); column 0 holds N
|
||||
|
||||
|
||||
def make_moe_test_setup(
    num_experts: int,
    K: int,
    N: int,
    *,
    alignment: int = ALIGNMENT,
    max_blocks: int = 64,
    device: str = "cuda",
    random_zero: bool = False,
) -> MoETestSetup:
    """Create a full set of tensors for testing cutlass_w4a8_moe_mm.

    Args:
        num_experts: Number of experts (independent GEMM problems).
        K: Reduction dimension; must be a multiple of GROUP_SIZE.
        N: Output dimension per expert.
        alignment: Per-expert token counts are drawn as multiples of this.
        max_blocks: Upper bound (in units of `alignment`) on tokens per expert.
        device: Device on which all tensors are allocated.
        random_zero: If True, set some experts' token counts to zero.

    NOTE(review): reproducibility under a fixed seed depends on the exact
    order of RNG calls below (random.randint, random.sample, torch.randn);
    do not reorder them.
    """

    assert K % GROUP_SIZE == 0
    # Token counts per expert (multiples of `alignment`).
    Ms = [alignment * random.randint(1, max_blocks) for _ in range(num_experts)]

    # Set a random subset of experts to 0 tokens (exercises empty groups).
    if random_zero and num_experts > 1:
        num_zero = max(1, num_experts // 8)
        zero_indices = random.sample(range(num_experts), k=num_zero)
        for idx in zero_indices:
            Ms[idx] = 0

    M_full = sum(Ms)
    assert M_full > 0

    # Activations: fp8 input plus a float32 copy for the reference path.
    a = to_fp8(torch.randn((M_full, K), device=device))
    a_ref = a.to(torch.float32)
    a_strides = torch.full((num_experts,), K, dtype=torch.int64, device=device)

    # Output buffer.
    out = torch.empty((M_full, N), dtype=torch.bfloat16, device=device)
    c_strides = torch.full((num_experts,), N, dtype=torch.int64, device=device)

    # Channel/token scales.
    per_tok_scales = torch.randn((M_full, 1), dtype=torch.float32, device=device)
    per_chan_scales = torch.randn(
        (num_experts, N, 1), dtype=torch.float32, device=device
    )

    # Expert weights and scales: quantize each expert independently.
    wtype = scalar_types.int4
    atype = stype = torch.float8_e4m3fn
    w_refs, w_qs, w_ss = [], [], []
    for _ in range(num_experts):
        b = to_fp8(torch.randn((K, N), device=device))
        w_ref, w_q, w_s, _ = cutlass_quantize(
            atype, b.to(torch.float16), wtype, stype, GROUP_SIZE, zero_points=False
        )
        w_refs.append(w_ref)
        w_qs.append(w_q)
        w_ss.append(w_s)

    w_q_packed, w_s_packed, packed_layout = cutlass_preprocess(w_qs, w_ss)

    # Per-expert problem sizes, stored as [N, M, K].
    problem_sizes = torch.tensor(
        [[N, M, K] for M in Ms], dtype=torch.int32, device=device
    )

    # Start row of each expert's token slab: exclusive prefix sum of Ms.
    expert_offsets = torch.cat(
        [
            torch.tensor([0], dtype=torch.int64),
            torch.cumsum(torch.tensor(Ms, dtype=torch.int64), dim=0)[:-1],
        ]
    ).to(device=device)

    # B strides come straight from the packing layout; group scale strides
    # carry N in column 0 (column 1 stays zero).
    b_strides = packed_layout
    group_scale_strides = torch.zeros(
        (num_experts, 2), dtype=torch.int64, device=device
    )
    group_scale_strides[:, 0] = N

    return MoETestSetup(
        num_experts=num_experts,
        K=K,
        N=N,
        Ms=Ms,
        M_full=M_full,
        a=a,
        a_ref=a_ref,
        a_strides=a_strides,
        out=out,
        c_strides=c_strides,
        per_tok_scales=per_tok_scales,
        per_chan_scales=per_chan_scales,
        w_refs=w_refs,
        w_q_packed=w_q_packed,
        w_s_packed=w_s_packed,
        problem_sizes=problem_sizes,
        expert_offsets=expert_offsets,
        b_strides=b_strides,
        group_scale_strides=group_scale_strides,
    )
|
||||
|
||||
|
||||
def compute_moe_reference_output(setup: MoETestSetup) -> torch.Tensor:
    """Compute the reference output with one torch._scaled_mm per expert."""
    reference = torch.empty_like(setup.out)

    row_starts = setup.expert_offsets.cpu().tolist()
    row_ends = torch.cumsum(torch.tensor(setup.Ms), 0).tolist()

    for expert_idx, (start, end) in enumerate(zip(row_starts, row_ends)):
        # Experts that received no tokens contribute no output rows.
        if start == end:
            continue

        a_slice = setup.a_ref[start:end].to(torch.float8_e4m3fn)
        # t().contiguous().t() produces the column-major B that _scaled_mm wants.
        b_mat = setup.w_refs[expert_idx].to(torch.float8_e4m3fn).t().contiguous().t()
        reference[start:end] = torch._scaled_mm(
            a_slice,
            b_mat,
            setup.per_tok_scales[start:end],  # (M, 1)
            setup.per_chan_scales[expert_idx].reshape(1, -1),  # (1, N)
            out_dtype=torch.bfloat16,
            use_fast_accum=True,
        )

    return reference
|
||||
|
||||
|
||||
@pytest.mark.skipif(
    not IS_SUPPORTED_BY_GPU,
    reason="W4A8 Grouped GEMM is not supported on this GPU type.",
)
@pytest.mark.parametrize("shape", TEST_SHAPES)
@pytest.mark.parametrize("random_zero", [True, False])
def test_cutlass_w4a8_moe_mm_end_to_end(shape, random_zero):
    """Run the grouped kernel on a random problem and compare to the reference."""
    num_experts, N, K = shape
    current_platform.seed_everything(42)
    s = make_moe_test_setup(
        num_experts=num_experts, K=K, N=N, max_blocks=64, random_zero=random_zero
    )

    ops.cutlass_w4a8_moe_mm(
        s.out,
        s.a,
        s.w_q_packed,
        s.per_tok_scales,
        s.per_chan_scales,
        s.w_s_packed,
        GROUP_SIZE,
        s.expert_offsets,
        s.problem_sizes,
        s.a_strides,
        s.b_strides,
        s.c_strides,
        s.group_scale_strides,
    )
    torch.cuda.synchronize()

    expected = compute_moe_reference_output(s)
    torch.testing.assert_close(s.out, expected, rtol=1e-2, atol=1e-2)
|
||||
|
||||
|
||||
class W4A8MoELayer(torch.nn.Module):
    """Thin nn.Module wrapper around the grouped GEMM for CUDA graph capture."""

    def __init__(self, setup: MoETestSetup):
        super().__init__()
        self.setup = setup

    def forward(self, a: torch.Tensor) -> torch.Tensor:
        cfg = self.setup
        # The kernel writes into the preallocated cfg.out buffer in place.
        ops.cutlass_w4a8_moe_mm(
            cfg.out,
            a,
            cfg.w_q_packed,
            cfg.per_tok_scales,
            cfg.per_chan_scales,
            cfg.w_s_packed,
            GROUP_SIZE,
            cfg.expert_offsets,
            cfg.problem_sizes,
            cfg.a_strides,
            cfg.b_strides,
            cfg.c_strides,
            cfg.group_scale_strides,
        )
        return cfg.out
|
||||
|
||||
|
||||
@pytest.mark.skipif(
    not IS_SUPPORTED_BY_GPU,
    reason="W4A8 Grouped GEMM is not supported on this GPU type.",
)
def test_cutlass_w4a8_moe_mm_cuda_graph():
    """Capture the grouped GEMM in a CUDA graph and verify the replayed output."""
    current_platform.seed_everything(42)
    # Fixed config for CUDA graph test (single parameter point).
    num_experts, K, N = 8, 512, 2048

    setup = make_moe_test_setup(
        num_experts=num_experts,
        K=K,
        N=N,
        max_blocks=32,
    )

    # Model wrapping the grouped GEMM kernel call.
    model = W4A8MoELayer(setup)

    # Reference output, computed once before any graph work.
    out_ref = compute_moe_reference_output(setup)

    # Static input tensor the captured graph will read on every replay.
    a_static = setup.a.clone()

    capture_stream = torch.cuda.Stream()
    with torch.cuda.stream(capture_stream):
        graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(graph):
            out_static = model(a_static)

    # Clear the output so replay must rewrite it, then replay the graph.
    out_static.zero_()
    graph.replay()

    torch.testing.assert_close(out_static, out_ref, rtol=1e-2, atol=1e-2)
|
||||
Reference in New Issue
Block a user