Add awq dequantize kernel to sgl with 1x to 3x speedup (#4104)
This commit is contained in:
67
sgl-kernel/tests/test_awq_dequant.py
Normal file
67
sgl-kernel/tests/test_awq_dequant.py
Normal file
@@ -0,0 +1,67 @@
|
||||
import itertools
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from sgl_kernel import awq_dequantize
|
||||
from vllm import _custom_ops as ops
|
||||
|
||||
|
||||
def vllm_awq_dequantize(
|
||||
qweight: torch.Tensor, scales: torch.Tensor, qzeros: torch.Tensor
|
||||
) -> torch.Tensor:
|
||||
return ops.awq_dequantize(qweight, scales, qzeros, 0, 0, 0)
|
||||
|
||||
|
||||
def sglang_awq_dequantize(
|
||||
qweight: torch.Tensor, scales: torch.Tensor, qzeros: torch.Tensor
|
||||
) -> torch.Tensor:
|
||||
return awq_dequantize(qweight, scales, qzeros)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"qweight_row,qweight_col",
|
||||
list(
|
||||
itertools.product(
|
||||
[3584, 18944, 128, 256, 512, 1024], [448, 576, 4736, 16, 32, 64, 128]
|
||||
)
|
||||
),
|
||||
)
|
||||
def test_awq_dequant_compare_implementations(
|
||||
qweight_row: int,
|
||||
qweight_col: int,
|
||||
):
|
||||
device = torch.device("cuda")
|
||||
|
||||
qweight = torch.randint(
|
||||
0,
|
||||
torch.iinfo(torch.int32).max,
|
||||
(qweight_row, qweight_col),
|
||||
dtype=torch.int32,
|
||||
device=device,
|
||||
)
|
||||
group_size = qweight_row
|
||||
scales_row = qweight_row // group_size
|
||||
scales_col = qweight_col * 8
|
||||
scales = torch.rand(scales_row, scales_col, dtype=torch.float16, device=device)
|
||||
qzeros = torch.randint(
|
||||
0,
|
||||
torch.iinfo(torch.int32).max,
|
||||
(scales_row, qweight_col),
|
||||
dtype=torch.int32,
|
||||
device=device,
|
||||
)
|
||||
|
||||
# Run both implementations
|
||||
vllm_out = vllm_awq_dequantize(qweight, scales, qzeros)
|
||||
sglang_out = sglang_awq_dequantize(qweight, scales, qzeros)
|
||||
|
||||
# Compare results
|
||||
torch.testing.assert_close(
|
||||
vllm_out.to(torch.float32), sglang_out.to(torch.float32), rtol=1e-3, atol=1e-5
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Run the specific test function directly
|
||||
pytest.main([__file__])
|
||||
Reference in New Issue
Block a user