154 lines
4.4 KiB
Python
154 lines
4.4 KiB
Python
import itertools
|
|
import os
|
|
from typing import List, Tuple
|
|
|
|
import torch
|
|
import triton
|
|
import triton.testing
|
|
from sgl_kernel import awq_dequantize
|
|
|
|
# Optional vLLM import
|
|
try:
|
|
from vllm import _custom_ops as ops
|
|
|
|
VLLM_AVAILABLE = True
|
|
except ImportError:
|
|
ops = None
|
|
VLLM_AVAILABLE = False
|
|
|
|
# CI environment detection
|
|
IS_CI = (
|
|
os.getenv("CI", "false").lower() == "true"
|
|
or os.getenv("GITHUB_ACTIONS", "false").lower() == "true"
|
|
)
|
|
|
|
|
|
def vllm_awq_dequantize(
|
|
qweight: torch.Tensor, scales: torch.Tensor, qzeros: torch.Tensor
|
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
|
if not VLLM_AVAILABLE:
|
|
# Fallback to SGLang implementation
|
|
return sglang_awq_dequantize(qweight, scales, qzeros)
|
|
return ops.awq_dequantize(qweight, scales, qzeros, 0, 0, 0)
|
|
|
|
|
|
def sglang_awq_dequantize(
|
|
qweight: torch.Tensor, scales: torch.Tensor, qzeros: torch.Tensor
|
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
|
|
|
return awq_dequantize(qweight, scales, qzeros)
|
|
|
|
|
|
def calculate_diff(qweight_row: int, qweight_col: int):
|
|
"""Calculate difference between VLLM and SGLang implementations."""
|
|
device = torch.device("cuda")
|
|
qweight = torch.randint(
|
|
0,
|
|
torch.iinfo(torch.int32).max,
|
|
(qweight_row, qweight_col),
|
|
dtype=torch.int32,
|
|
device=device,
|
|
)
|
|
group_size = qweight_row
|
|
scales_row = qweight_row // group_size
|
|
scales_col = qweight_col * 8
|
|
scales = torch.rand(scales_row, scales_col, dtype=torch.float16, device=device)
|
|
qzeros = torch.randint(
|
|
0,
|
|
torch.iinfo(torch.int32).max,
|
|
(scales_row, qweight_col),
|
|
dtype=torch.int32,
|
|
device=device,
|
|
)
|
|
|
|
if not VLLM_AVAILABLE:
|
|
print("⚠️ vLLM not available, skipping comparison")
|
|
return
|
|
|
|
vllm_out = vllm_awq_dequantize(qweight, scales, qzeros)
|
|
sglang_out = sglang_awq_dequantize(qweight, scales, qzeros)
|
|
|
|
output_diff = torch.abs(vllm_out.float() - sglang_out.float()).mean().item()
|
|
|
|
if torch.allclose(
|
|
vllm_out.to(torch.float32), sglang_out.to(torch.float32), rtol=1e-3, atol=1e-5
|
|
):
|
|
print("✅ All implementations match")
|
|
else:
|
|
print("❌ Implementations differ")
|
|
|
|
|
|
# CI environment uses simplified parameters
|
|
if IS_CI:
|
|
qweight_row_range = [128] # Single row size for CI
|
|
qweight_cols_range = [16] # Single column size for CI
|
|
else:
|
|
qweight_row_range = [3584, 18944, 128, 256, 512, 1024]
|
|
qweight_cols_range = [448, 576, 4736, 16, 32, 64, 128]
|
|
|
|
configs = list(itertools.product(qweight_row_range, qweight_cols_range))
|
|
|
|
|
|
@triton.testing.perf_report(
|
|
triton.testing.Benchmark(
|
|
x_names=["qweight_row", "qweight_col"],
|
|
x_vals=configs,
|
|
line_arg="provider",
|
|
line_vals=["vllm", "sglang"] if VLLM_AVAILABLE else ["sglang"],
|
|
line_names=["VLLM", "SGL Kernel"] if VLLM_AVAILABLE else ["SGL Kernel"],
|
|
styles=[("blue", "-"), ("green", "-")] if VLLM_AVAILABLE else [("green", "-")],
|
|
ylabel="us",
|
|
plot_name="awq-dequantize-performance",
|
|
args={},
|
|
)
|
|
)
|
|
def benchmark(qweight_row, qweight_col, provider):
|
|
dtype = torch.float16
|
|
device = torch.device("cuda")
|
|
qweight = torch.randint(
|
|
0,
|
|
torch.iinfo(torch.int32).max,
|
|
(qweight_row, qweight_col),
|
|
dtype=torch.int32,
|
|
device=device,
|
|
)
|
|
group_size = qweight_row
|
|
scales_row = qweight_row // group_size
|
|
scales_col = qweight_col * 8
|
|
scales = torch.rand(scales_row, scales_col, dtype=torch.float16, device=device)
|
|
qzeros = torch.randint(
|
|
0,
|
|
torch.iinfo(torch.int32).max,
|
|
(scales_row, qweight_col),
|
|
dtype=torch.int32,
|
|
device=device,
|
|
)
|
|
|
|
quantiles = [0.5, 0.2, 0.8]
|
|
|
|
if provider == "vllm":
|
|
if not VLLM_AVAILABLE:
|
|
return (0, 0, 0)
|
|
fn = lambda: vllm_awq_dequantize(
|
|
qweight.clone(), scales.clone(), qzeros.clone()
|
|
)
|
|
elif provider == "sglang":
|
|
fn = lambda: sglang_awq_dequantize(
|
|
qweight.clone(), scales.clone(), qzeros.clone()
|
|
)
|
|
|
|
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(fn, quantiles=quantiles)
|
|
|
|
return 1000 * ms, 1000 * max_ms, 1000 * min_ms
|
|
|
|
|
|
if __name__ == "__main__":
|
|
# Simplify for CI environment
|
|
if IS_CI:
|
|
qweight_row, qweight_col = 128, 16 # Smaller values for CI
|
|
else:
|
|
qweight_row, qweight_col = 3584, 448
|
|
|
|
calculate_diff(qweight_row=qweight_row, qweight_col=qweight_col)
|
|
benchmark.run(print_data=True)
|