fix accuracy issue (#4376)
This commit is contained in:
@@ -22,9 +22,10 @@ def vllm_per_token_quant_fp8(
|
||||
def sglang_per_token_quant_fp8(
|
||||
input: torch.Tensor,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
scale = torch.zeros((input.size(0), 1), device=input.device, dtype=torch.float32)
|
||||
scale = torch.zeros(input.size(0), device=input.device, dtype=torch.float32)
|
||||
output = torch.empty_like(input, device=input.device, dtype=fp8_type_)
|
||||
sgl_per_token_quant_fp8(input, output, scale)
|
||||
|
||||
return output, scale
|
||||
|
||||
|
||||
@@ -36,6 +37,9 @@ def calculate_diff(batch_size: int, seq_len: int):
|
||||
vllm_out, vllm_scale = vllm_per_token_quant_fp8(x)
|
||||
sglang_out, sglang_scale = sglang_per_token_quant_fp8(x)
|
||||
|
||||
scale_diff = torch.abs(vllm_scale - sglang_scale).mean().item()
|
||||
output_diff = torch.abs(vllm_out.float() - sglang_out.float()).mean().item()
|
||||
|
||||
if torch.allclose(
|
||||
vllm_out.to(torch.float32), sglang_out.to(torch.float32), rtol=1e-3, atol=1e-5
|
||||
) and torch.allclose(vllm_scale, sglang_scale, rtol=1e-3, atol=1e-5):
|
||||
|
||||
Reference in New Issue
Block a user