Remove vllm ops scaled fp8 quant and accelerate per token quant by 20-28% (#4215)
Co-authored-by: Stefan He <bhe@linkedin.com>
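For context: per-token dynamic FP8 quantization computes one scale per row (token) instead of a single scale for the whole tensor. A minimal pure-PyTorch sketch of that math is below; it is an illustration under assumed shapes, not the fused kernel this commit accelerates.

import torch

def per_token_fp8_quant_reference(x: torch.Tensor):
    # One scale per row: scale[i] = max(abs(x[i])) / FP8_E4M3_MAX
    finfo = torch.finfo(torch.float8_e4m3fn)
    scale = x.abs().amax(dim=-1, keepdim=True).to(torch.float32) / finfo.max
    scale = scale.clamp(min=1e-12)  # avoid division by zero for all-zero rows
    q = (x.to(torch.float32) / scale).clamp(min=finfo.min, max=finfo.max)
    return q.to(torch.float8_e4m3fn), scale

x = torch.randn(11, 16, dtype=torch.bfloat16) * 13        # hypothetical shape
q, scale = per_token_fp8_quant_reference(x)               # q: (11, 16) fp8, scale: (11, 1) fp32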
python/sglang/test/test_custom_ops.py (new file, 88 lines added)
@@ -0,0 +1,88 @@
# Adapted from https://github.com/vllm-project/vllm/blob/8ca7a71df787ad711ad3ac70a5bd2eb2bb398938/tests/quantization/test_fp8.py

import pytest
import torch

from sglang.srt.custom_op import scaled_fp8_quant
from sglang.srt.utils import is_cuda


@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
def test_scaled_fp8_quant_per_tensor(dtype) -> None:

    def quantize_ref_per_tensor(tensor, inv_scale):
        # The reference implementation that fully aligns to
        # the kernel being tested.
        finfo = torch.finfo(torch.float8_e4m3fn)
        scale = inv_scale.reciprocal()
        qweight = (tensor.to(torch.float32) * scale).clamp(min=finfo.min, max=finfo.max)
        qweight = qweight.to(torch.float8_e4m3fn)
        return qweight

    def dequantize_per_tensor(tensor, inv_scale, dtype):
        fake_qweight = tensor.to(dtype)
        dq_weight = fake_qweight * inv_scale
        return dq_weight

    # Note that we use a shape % 8 != 0 to cover edge cases,
    # because scaled_fp8_quant is vectorized by 8.
    x = (torch.randn(size=(11, 11), device="cuda") * 13).to(dtype)

    # Test Per Tensor Dynamic quantization
    # scale = max(abs(x)) / FP8_E4M3_MAX
    y, scale = scaled_fp8_quant(x, None)
    ref_y = quantize_ref_per_tensor(x, scale)
    torch.testing.assert_close(y, ref_y)
    torch.testing.assert_close(
        dequantize_per_tensor(y, scale, dtype),
        dequantize_per_tensor(ref_y, scale, dtype),
    )

    # Test Per Tensor Static quantization
    y, _ = scaled_fp8_quant(x, scale)
    ref_y = quantize_ref_per_tensor(x, scale)
    torch.testing.assert_close(y, ref_y)
    torch.testing.assert_close(
        dequantize_per_tensor(y, scale, dtype),
        dequantize_per_tensor(ref_y, scale, dtype),
    )


if is_cuda:

    @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
    def test_scaled_fp8_quant_per_token_dynamic(dtype) -> None:
        def quantize_ref_per_token(tensor, inv_scale):
            # The reference implementation that fully aligns to
            # the kernel being tested.
            finfo = torch.finfo(torch.float8_e4m3fn)
            scale = inv_scale.reciprocal()
            qweight = (tensor.to(torch.float32) * scale).clamp(
                min=finfo.min, max=finfo.max
            )
            qweight = qweight.to(torch.float8_e4m3fn)
            return qweight

        def dequantize_per_token(tensor, inv_scale, dtype):
            fake_qweight = tensor.to(dtype)
            dq_weight = fake_qweight * inv_scale
            return dq_weight

        # Note that we use a shape whose last dim % 8 == 0,
        # because per_token_quant_fp8 is vectorized by 8 elements.
        x = (torch.randn(size=(11, 16), device="cuda") * 13).to(dtype)

        # Test Per Token Dynamic quantization
        # scale[i] = max(abs(x[i])) / FP8_E4M3_MAX
        y, scale = scaled_fp8_quant(x, None, use_per_token_if_dynamic=True)
        ref_y = quantize_ref_per_token(x, scale)
        torch.testing.assert_close(y, ref_y)
        torch.testing.assert_close(
            dequantize_per_token(y, scale, dtype),
            dequantize_per_token(ref_y, scale, dtype),
        )


if __name__ == "__main__":
    # Run the specific test function directly
    pytest.main([__file__])
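A rough way to sanity-check the 20-28% speedup claimed in the title is a micro-benchmark like the sketch below; the tensor shape, iteration counts, and event-based timing are assumptions for illustration, not part of this commit.

import torch
from sglang.srt.custom_op import scaled_fp8_quant

def bench(fn, iters=100):
    # Time a CUDA kernel with events; warm up first so timings are stable.
    for _ in range(10):
        fn()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters  # ms per call

x = torch.randn(4096, 4096, device="cuda", dtype=torch.bfloat16)  # hypothetical shape
ms = bench(lambda: scaled_fp8_quant(x, None, use_per_token_if_dynamic=True))
print(f"per-token dynamic fp8 quant: {ms:.3f} ms")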