Fix quant kernel test errors and wrong output speeds in benchmark (#7604)
This commit is contained in:
@@ -341,6 +341,39 @@ def create_per_token_group_quant_fp8_output_scale(
|
||||
)
|
||||
|
||||
|
||||
# TODO maybe unify int8 and fp8 code later
def per_token_group_quant_8bit(
    x: torch.Tensor,
    group_size: int,
    dst_dtype: torch.dtype,
    eps: float = 1e-10,
    column_major_scales: bool = False,
    scale_tma_aligned: bool = False,
    scale_ue8m0: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Dispatch per-token-group quantization to the int8 or fp8 kernel.

    ``dst_dtype == torch.int8`` routes to the int8 kernel, which supports
    none of the fp8 scale-layout flags; any other dtype routes to the fp8
    kernel with the flags forwarded unchanged.

    Returns:
        ``(x_q, x_s)``: the quantized tensor and its per-group scales.
    """
    from sglang.srt.layers.quantization.int8_kernel import per_token_group_quant_int8

    if dst_dtype != torch.int8:
        # fp8 path: pass every scale-layout option straight through.
        return per_token_group_quant_fp8(
            x=x,
            group_size=group_size,
            eps=eps,
            column_major_scales=column_major_scales,
            scale_tma_aligned=scale_tma_aligned,
            scale_ue8m0=scale_ue8m0,
        )

    # int8 path: the scale-layout flags are fp8-only and must stay disabled.
    assert not column_major_scales
    assert not scale_tma_aligned
    assert not scale_ue8m0
    return per_token_group_quant_int8(
        x=x,
        group_size=group_size,
        eps=eps,
        dtype=dst_dtype,
    )
|
||||
|
||||
|
||||
def sglang_per_token_group_quant_fp8(
|
||||
x: torch.Tensor,
|
||||
group_size: int,
|
||||
@@ -372,6 +405,40 @@ def sglang_per_token_group_quant_fp8(
|
||||
return x_q, x_s
|
||||
|
||||
|
||||
# TODO maybe unify int8 and fp8 code later
def sglang_per_token_group_quant_8bit(
    x: torch.Tensor,
    group_size: int,
    dst_dtype: torch.dtype,
    eps: float = 1e-10,
    column_major_scales: bool = False,
    scale_tma_aligned: bool = False,
    scale_ue8m0: bool = False,
):
    """Quantize ``x`` per token group to int8 or fp8, dispatching on ``dst_dtype``.

    Args:
        x: input tensor; quantized in groups of ``group_size`` elements
           (exact layout is defined by the underlying kernels).
        group_size: number of elements sharing one scale.
        dst_dtype: ``torch.int8`` selects the int8 kernel; any other dtype
            falls through to the fp8 kernel.
        eps: lower bound on the per-group max to avoid division by zero.
        column_major_scales / scale_tma_aligned / scale_ue8m0: fp8-only
            scale-layout options; must be left at their defaults for int8.

    Returns:
        Whatever the selected kernel returns (quantized tensor and scales).
    """
    from sglang.srt.layers.quantization.int8_kernel import (
        sglang_per_token_group_quant_int8,
    )

    if dst_dtype == torch.int8:
        # The int8 kernel has no notion of the fp8 scale layouts; reject them
        # loudly instead of silently ignoring them. The original omitted the
        # scale_ue8m0 check, inconsistent with per_token_group_quant_8bit,
        # which rejects all three flags on the int8 path.
        assert not column_major_scales
        assert not scale_tma_aligned
        assert not scale_ue8m0
        return sglang_per_token_group_quant_int8(
            x=x,
            group_size=group_size,
            eps=eps,
            dtype=dst_dtype,
        )

    return sglang_per_token_group_quant_fp8(
        x=x,
        group_size=group_size,
        eps=eps,
        column_major_scales=column_major_scales,
        scale_tma_aligned=scale_tma_aligned,
        scale_ue8m0=scale_ue8m0,
    )
|
||||
|
||||
|
||||
def sglang_per_token_quant_fp8(
|
||||
x: torch.Tensor,
|
||||
dtype: torch.dtype = fp8_dtype,
|
||||
|
||||
@@ -176,6 +176,27 @@ def replace_parameter(
|
||||
mod.register_parameter(name, torch.nn.Parameter(new, requires_grad=False))
|
||||
|
||||
|
||||
def assert_fp8_all_close(a: torch.Tensor, b: torch.Tensor):
|
||||
assert a.shape == b.shape
|
||||
assert a.dtype == b.dtype == torch.float8_e4m3fn
|
||||
|
||||
a_u8 = a.view(torch.uint8)
|
||||
b_u8 = b.view(torch.uint8)
|
||||
diff_u8 = (a_u8.to(torch.int16) - b_u8.to(torch.int16)).abs()
|
||||
|
||||
numel = a.numel()
|
||||
|
||||
count_diff_sign = ((a_u8 >= 0) & (b_u8 < 0)).sum().item()
|
||||
count_tiny_diff = (diff_u8 >= 1).sum().item()
|
||||
count_large_diff = (diff_u8 >= 2).sum().item()
|
||||
|
||||
assert (
|
||||
(count_diff_sign == 0)
|
||||
and (count_tiny_diff / numel < 0.005)
|
||||
and (count_large_diff == 0)
|
||||
), f"{count_diff_sign=} {count_tiny_diff=} {count_large_diff=} {numel=}"
|
||||
|
||||
|
||||
# Match dynamic rules with module name (prefix) and override quantize
|
||||
# config if module (prefix) matches a rule
|
||||
def override_config(config: QuantizationConfig, prefix: str):
|
||||
|
||||
Reference in New Issue
Block a user