Fix quant kernel test errors and benchmark wrong output speeds (#7604)

This commit is contained in:
fzyzcjy
2025-08-21 18:48:41 +08:00
committed by GitHub
parent 55d336cb08
commit e85cb1ce9d
4 changed files with 205 additions and 463 deletions

View File

@@ -341,6 +341,39 @@ def create_per_token_group_quant_fp8_output_scale(
)
# TODO maybe unify int8 and fp8 code later
def per_token_group_quant_8bit(
    x: torch.Tensor,
    group_size: int,
    dst_dtype: torch.dtype,
    eps: float = 1e-10,
    column_major_scales: bool = False,
    scale_tma_aligned: bool = False,
    scale_ue8m0: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Per-token-group quantization, dispatching on ``dst_dtype``.

    Routes to the int8 kernel when ``dst_dtype`` is ``torch.int8`` and to the
    fp8 kernel otherwise. Returns the quantized tensor and its scales.

    Args:
        x: Input tensor to quantize.
        group_size: Number of elements per quantization group.
        dst_dtype: Target dtype; ``torch.int8`` selects the int8 path.
        eps: Small constant guarding the scale computation.
        column_major_scales: fp8-only scale-layout option.
        scale_tma_aligned: fp8-only scale-layout option.
        scale_ue8m0: fp8-only scale-format option.
    """
    # TODO maybe unify int8 and fp8 code later
    from sglang.srt.layers.quantization.int8_kernel import per_token_group_quant_int8

    if dst_dtype != torch.int8:
        # fp8 path: forwards all scale-layout options unchanged.
        return per_token_group_quant_fp8(
            x=x,
            group_size=group_size,
            eps=eps,
            column_major_scales=column_major_scales,
            scale_tma_aligned=scale_tma_aligned,
            scale_ue8m0=scale_ue8m0,
        )

    # int8 path: the kernel exposes no scale-layout options, so callers
    # must not request any of them.
    assert not column_major_scales
    assert not scale_tma_aligned
    assert not scale_ue8m0
    return per_token_group_quant_int8(
        x=x,
        group_size=group_size,
        eps=eps,
        dtype=dst_dtype,
    )
def sglang_per_token_group_quant_fp8(
x: torch.Tensor,
group_size: int,
@@ -372,6 +405,40 @@ def sglang_per_token_group_quant_fp8(
return x_q, x_s
# TODO maybe unify int8 and fp8 code later
def sglang_per_token_group_quant_8bit(
    x: torch.Tensor,
    group_size: int,
    dst_dtype: torch.dtype,
    eps: float = 1e-10,
    column_major_scales: bool = False,
    scale_tma_aligned: bool = False,
    scale_ue8m0: bool = False,
):
    """Per-token-group quantization (sglang kernels), dispatching on ``dst_dtype``.

    Routes to ``sglang_per_token_group_quant_int8`` when ``dst_dtype`` is
    ``torch.int8`` and to ``sglang_per_token_group_quant_fp8`` otherwise.

    Args:
        x: Input tensor to quantize.
        group_size: Number of elements per quantization group.
        dst_dtype: Target dtype; ``torch.int8`` selects the int8 path.
        eps: Small constant guarding the scale computation.
        column_major_scales: fp8-only scale-layout option.
        scale_tma_aligned: fp8-only scale-layout option.
        scale_ue8m0: fp8-only scale-format option.
    """
    # TODO maybe unify int8 and fp8 code later
    from sglang.srt.layers.quantization.int8_kernel import (
        sglang_per_token_group_quant_int8,
    )

    if dst_dtype == torch.int8:
        # The int8 kernel accepts none of the scale-layout options; reject
        # them loudly rather than silently ignoring them.
        assert not column_major_scales
        assert not scale_tma_aligned
        # Fix: this assert was missing, so scale_ue8m0=True would have been
        # silently dropped on the int8 path.
        assert not scale_ue8m0
        return sglang_per_token_group_quant_int8(
            x=x,
            group_size=group_size,
            eps=eps,
            dtype=dst_dtype,
        )

    return sglang_per_token_group_quant_fp8(
        x=x,
        group_size=group_size,
        eps=eps,
        column_major_scales=column_major_scales,
        scale_tma_aligned=scale_tma_aligned,
        scale_ue8m0=scale_ue8m0,
    )
def sglang_per_token_quant_fp8(
x: torch.Tensor,
dtype: torch.dtype = fp8_dtype,

View File

@@ -176,6 +176,27 @@ def replace_parameter(
mod.register_parameter(name, torch.nn.Parameter(new, requires_grad=False))
def assert_fp8_all_close(a: torch.Tensor, b: torch.Tensor):
assert a.shape == b.shape
assert a.dtype == b.dtype == torch.float8_e4m3fn
a_u8 = a.view(torch.uint8)
b_u8 = b.view(torch.uint8)
diff_u8 = (a_u8.to(torch.int16) - b_u8.to(torch.int16)).abs()
numel = a.numel()
count_diff_sign = ((a_u8 >= 0) & (b_u8 < 0)).sum().item()
count_tiny_diff = (diff_u8 >= 1).sum().item()
count_large_diff = (diff_u8 >= 2).sum().item()
assert (
(count_diff_sign == 0)
and (count_tiny_diff / numel < 0.005)
and (count_large_diff == 0)
), f"{count_diff_sign=} {count_tiny_diff=} {count_large_diff=} {numel=}"
# Match dynamic rules against the module name (prefix) and override the
# quantization config when the module (prefix) matches a rule
def override_config(config: QuantizationConfig, prefix: str):