Fix per token fp8 quant precision (#4362)

Qingquan Song authored 2025-03-12 21:19:05 -07:00 (committed by GitHub)
parent 817d43705c
commit 4068e01292
3 changed files with 5 additions and 13 deletions


@@ -21,18 +21,16 @@ def vllm_per_token_quant_fp8(
 def sglang_per_token_quant_fp8(
     input: torch.Tensor,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
-    scale = torch.zeros(input.size(0), device=input.device, dtype=torch.float32)
+    scale = torch.zeros((input.size(0), 1), device=input.device, dtype=torch.float32)
     output = torch.empty_like(input, device=input.device, dtype=fp8_type_)
     sgl_per_token_quant_fp8(input, output, scale)
-    scale = scale.reshape(-1, 1)
     return output, scale


 @pytest.mark.parametrize(
     "num_tokens,hidden_dim",
-    list(itertools.product([128, 256, 512], [512, 2048, 4096])),
+    list(itertools.product([32, 64, 128, 256, 512], [128, 256, 512, 2048, 4096])),
 )
 def test_per_token_quant_compare_implementations(
     num_tokens: int,
@@ -44,7 +42,7 @@ def test_per_token_quant_compare_implementations(
     vllm_out, vllm_scale = vllm_per_token_quant_fp8(x)
     sglang_out, sglang_scale = sglang_per_token_quant_fp8(x)

-    torch.testing.assert_close(vllm_scale, sglang_scale, rtol=1e-3, atol=1e-3)
+    torch.testing.assert_close(vllm_scale, sglang_scale, rtol=1e-3, atol=1e-5)
     torch.testing.assert_close(
         vllm_out.float(), sglang_out.float(), rtol=1e-3, atol=1e-3
     )
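For context, below is a minimal pure-PyTorch sketch of per-token fp8 quantization. It is a hypothetical reference, not the sgl_per_token_quant_fp8 CUDA kernel patched above, and it assumes fp8_type_ is torch.float8_e4m3fn (available in torch >= 2.1). Like the fixed wrapper, it produces the per-token scale with shape (num_tokens, 1) directly, so no trailing reshape is needed.

import torch
from typing import Tuple

fp8_type_ = torch.float8_e4m3fn       # assumption: same fp8 dtype as the test file
FP8_MAX = torch.finfo(fp8_type_).max  # 448.0 for e4m3fn

def ref_per_token_quant_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    # One scale per token (row), kept as (num_tokens, 1) so it broadcasts
    # cleanly against the (num_tokens, hidden_dim) input.
    amax = x.abs().amax(dim=-1, keepdim=True).float()
    scale = (amax / FP8_MAX).clamp(min=1e-12)  # guard all-zero rows
    output = (x.float() / scale).clamp(-FP8_MAX, FP8_MAX).to(fp8_type_)
    return output, scale

Allocating the scale two-dimensionally up front also makes the assert_close check against vllm_scale a like-for-like shape comparison; the tightened atol=1e-5 on the scales reflects that, after the fix, the two implementations are expected to produce nearly identical per-token scales.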