[AMD] Remove vllm's scaled_fp8_quant and moe_sum when SGLANG_USE_AITER=1 (#7484)
@@ -3,8 +3,13 @@
 import pytest
 import torch
 
-from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
-from sglang.srt.utils import is_cuda
+from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz, scaled_fp8_quant
+from sglang.srt.utils import is_cuda, is_hip
+
+_is_cuda = is_cuda()
+_is_hip = is_hip()
+_is_fp8_fnuz = is_fp8_fnuz()
+fp8_dtype = torch.float8_e4m3fnuz if _is_fp8_fnuz else torch.float8_e4m3fn
 
 
 @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@@ -13,10 +18,10 @@ def test_scaled_fp8_quant_per_tensor(dtype) -> None:
     def quantize_ref_per_tensor(tensor, inv_scale):
         # The reference implementation that fully aligns to
         # the kernel being tested.
-        finfo = torch.finfo(torch.float8_e4m3fn)
+        finfo = torch.finfo(fp8_dtype)
         scale = inv_scale.reciprocal()
         qweight = (tensor.to(torch.float32) * scale).clamp(min=finfo.min, max=finfo.max)
-        qweight = qweight.to(torch.float8_e4m3fn)
+        qweight = qweight.to(fp8_dtype)
         return qweight
 
     def dequantize_per_tensor(tensor, inv_scale, dtype):
@@ -48,19 +53,19 @@ def test_scaled_fp8_quant_per_tensor(dtype) -> None:
     )
 
 
-if is_cuda:
+if _is_cuda or _is_hip:
 
     @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
     def test_scaled_fp8_quant_per_token_dynamic(dtype) -> None:
         def quantize_ref_per_token(tensor, inv_scale):
             # The reference implementation that fully aligns to
             # the kernel being tested.
-            finfo = torch.finfo(torch.float8_e4m3fn)
+            finfo = torch.finfo(fp8_dtype)
             scale = inv_scale.reciprocal()
             qweight = (tensor.to(torch.float32) * scale).clamp(
                 min=finfo.min, max=finfo.max
             )
-            qweight = qweight.to(torch.float8_e4m3fn)
+            qweight = qweight.to(fp8_dtype)
             return qweight
 
         def dequantize_per_token(tensor, inv_scale, dtype):
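
The core of the change is the platform-dependent FP8 dtype: ROCm GPUs that use the fnuz FP8 encoding (torch.float8_e4m3fnuz, max ±240) have a narrower range than the OCP torch.float8_e4m3fn format (max ±448), so the reference quantizer must take its finfo from whichever dtype the kernel actually emits. Below is a minimal, standalone sketch of the per-tensor round trip; it is an illustration, not part of the commit, and the hard-coded _is_fp8_fnuz flag stands in for the imported is_fp8_fnuz() helper.

import torch

# Stand-in for is_fp8_fnuz(): True on ROCm GPUs that use the fnuz FP8 encoding.
_is_fp8_fnuz = False
fp8_dtype = torch.float8_e4m3fnuz if _is_fp8_fnuz else torch.float8_e4m3fn


def quantize_ref(tensor, inv_scale):
    # Mirror of the reference quantizer in the test: scale, clamp to the
    # representable range of the selected FP8 dtype, then cast.
    finfo = torch.finfo(fp8_dtype)
    scale = inv_scale.reciprocal()
    qweight = (tensor.to(torch.float32) * scale).clamp(min=finfo.min, max=finfo.max)
    return qweight.to(fp8_dtype)


def dequantize(qweight, inv_scale, dtype):
    # Inverse operation: upcast, then multiply the inverse scale back in.
    return (qweight.to(torch.float32) * inv_scale).to(dtype)


x = torch.randn(4, 8, dtype=torch.float16)
inv_scale = x.abs().max().float() / torch.finfo(fp8_dtype).max
x_hat = dequantize(quantize_ref(x, inv_scale), inv_scale, x.dtype)
print((x - x_hat).abs().max())  # small quantization error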
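
The per-token test gated behind _is_cuda or _is_hip follows the same pattern with one inverse scale per row rather than one per tensor. The sketch below is an illustrative reimplementation of that dynamic per-token reference, not the kernel under test, and again assumes the fp8_dtype selection shown above.

import torch

fp8_dtype = torch.float8_e4m3fn  # torch.float8_e4m3fnuz on fnuz hardware


def quantize_ref_per_token(tensor):
    # Dynamic per-token quantization: each row derives its own inverse scale
    # from that row's absolute maximum; inv_scale has shape [num_tokens, 1].
    finfo = torch.finfo(fp8_dtype)
    inv_scale = tensor.abs().amax(dim=-1, keepdim=True).float() / finfo.max
    qweight = (tensor.to(torch.float32) / inv_scale).clamp(min=finfo.min, max=finfo.max)
    return qweight.to(fp8_dtype), inv_scale


x = torch.randn(16, 64, dtype=torch.bfloat16)
qweight, inv_scale = quantize_ref_per_token(x)
x_hat = (qweight.to(torch.float32) * inv_scale).to(x.dtype)
print((x - x_hat).abs().max())  # bounded by the FP8 relative precision per row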