[AMD] Remove vllm's scaled_fp8_quant and moe_sum when SGLANG_USE_AITER=1 (#7484)

Author: Hubert Lu
Date:   2025-07-21 17:33:19 -07:00
Committed by: GitHub
Parent: 69adc4f81c
Commit: e50109f2ed
8 changed files with 156 additions and 69 deletions
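For context, a minimal sketch of the gating the commit title describes, assuming the environment variable is read as shown; `_use_aiter` is an illustrative name, not necessarily the symbol used in the repository:

import os

from sglang.srt.utils import is_hip

# Illustrative only: prefer AITER kernels on ROCm when SGLANG_USE_AITER=1;
# the real flag name and dispatch sites inside sglang may differ.
_use_aiter = is_hip() and os.environ.get("SGLANG_USE_AITER", "0") == "1"

if _use_aiter:
    ...  # use AITER's fp8 quant and moe_sum kernels
else:
    ...  # keep the existing vLLM-derived kernels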


@@ -3,8 +3,13 @@
 import pytest
 import torch
 
-from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
-from sglang.srt.utils import is_cuda
+from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz, scaled_fp8_quant
+from sglang.srt.utils import is_cuda, is_hip
 
+_is_cuda = is_cuda()
+_is_hip = is_hip()
+_is_fp8_fnuz = is_fp8_fnuz()
+
+fp8_dtype = torch.float8_e4m3fnuz if _is_fp8_fnuz else torch.float8_e4m3fn
 
 @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@@ -13,10 +18,10 @@ def test_scaled_fp8_quant_per_tensor(dtype) -> None:
     def quantize_ref_per_tensor(tensor, inv_scale):
         # The reference implementation that fully aligns to
         # the kernel being tested.
-        finfo = torch.finfo(torch.float8_e4m3fn)
+        finfo = torch.finfo(fp8_dtype)
         scale = inv_scale.reciprocal()
         qweight = (tensor.to(torch.float32) * scale).clamp(min=finfo.min, max=finfo.max)
-        qweight = qweight.to(torch.float8_e4m3fn)
+        qweight = qweight.to(fp8_dtype)
         return qweight
 
     def dequantize_per_tensor(tensor, inv_scale, dtype):
@@ -48,19 +53,19 @@ def test_scaled_fp8_quant_per_tensor(dtype) -> None:
     )
 
-if is_cuda:
+if _is_cuda or _is_hip:
 
     @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
    def test_scaled_fp8_quant_per_token_dynamic(dtype) -> None:
        def quantize_ref_per_token(tensor, inv_scale):
            # The reference implementation that fully aligns to
            # the kernel being tested.
-            finfo = torch.finfo(torch.float8_e4m3fn)
+            finfo = torch.finfo(fp8_dtype)
            scale = inv_scale.reciprocal()
            qweight = (tensor.to(torch.float32) * scale).clamp(
                min=finfo.min, max=finfo.max
            )
-            qweight = qweight.to(torch.float8_e4m3fn)
+            qweight = qweight.to(fp8_dtype)
            return qweight
 
        def dequantize_per_token(tensor, inv_scale, dtype):
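For reference outside the diff, a self-contained sketch of the per-tensor round trip these tests exercise; fp8_dtype is hard-coded here to the ROCm fnuz variant, whereas the test selects it at import time via is_fp8_fnuz():

import torch

# Assumed fnuz platform for this sketch; e4m3fnuz uses a different exponent
# bias than e4m3fn and a smaller maximum representable value (240 vs 448).
fp8_dtype = torch.float8_e4m3fnuz

def quantize_per_tensor(tensor, inv_scale):
    # Mirror of the reference implementation: scale, clamp to the fp8
    # representable range, then cast down.
    finfo = torch.finfo(fp8_dtype)
    scale = inv_scale.reciprocal()
    q = (tensor.to(torch.float32) * scale).clamp(min=finfo.min, max=finfo.max)
    return q.to(fp8_dtype)

def dequantize_per_tensor(q, inv_scale, dtype):
    # Upcast, undo the scale, and cast back to the comparison dtype.
    return (q.to(torch.float32) * inv_scale).to(dtype)

x = torch.randn(16, 16)
inv_scale = x.abs().max() / torch.finfo(fp8_dtype).max
x_hat = dequantize_per_tensor(quantize_per_tensor(x, inv_scale), inv_scale, x.dtype)
torch.testing.assert_close(x, x_hat, atol=0.25, rtol=0.1)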