Fix sgl-kernel benchmark dead code (#11022)
This commit is contained in:
@@ -1,16 +1,34 @@
|
||||
import itertools
|
||||
import os
|
||||
from typing import List, Tuple
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.testing
|
||||
from sgl_kernel import awq_dequantize
|
||||
from vllm import _custom_ops as ops
|
||||
|
||||
# vLLM is an optional dependency: when it cannot be imported, `ops` is left
# as None and VLLM_AVAILABLE records that the vLLM code path must be skipped.
try:
    from vllm import _custom_ops as ops
except ImportError:
    ops, VLLM_AVAILABLE = None, False
else:
    VLLM_AVAILABLE = True
|
||||
|
||||
# CI environment detection: true when either the CI or the GITHUB_ACTIONS
# environment variable is set to "true" (compared case-insensitively).
IS_CI = any(
    os.getenv(flag, "false").lower() == "true" for flag in ("CI", "GITHUB_ACTIONS")
)
|
||||
|
||||
|
||||
def vllm_awq_dequantize(
    qweight: torch.Tensor, scales: torch.Tensor, qzeros: torch.Tensor
) -> torch.Tensor:
    """Dequantize AWQ weights with vLLM's kernel, if vLLM is installed.

    Args:
        qweight: Quantized weight tensor, forwarded unchanged.
        scales: Scale tensor, forwarded unchanged.
        qzeros: Zero-point tensor, forwarded unchanged.

    Returns:
        The single tensor produced by the selected dequantize implementation.
        When vLLM is not available the SGLang implementation is used, so the
        caller always receives a result.
    """
    if not VLLM_AVAILABLE:
        # Fallback to SGLang implementation
        return sglang_awq_dequantize(qweight, scales, qzeros)
    # NOTE(review): the three trailing zeros are passed through as-is;
    # presumably they are vLLM's split_k_iters / thx / thy — confirm against
    # the installed vLLM version.
    return ops.awq_dequantize(qweight, scales, qzeros, 0, 0, 0)
|
||||
|
||||
|
||||
@@ -43,6 +61,10 @@ def calculate_diff(qweight_row: int, qweight_col: int):
|
||||
device=device,
|
||||
)
|
||||
|
||||
if not VLLM_AVAILABLE:
|
||||
print("⚠️ vLLM not available, skipping comparison")
|
||||
return
|
||||
|
||||
vllm_out = vllm_awq_dequantize(qweight, scales, qzeros)
|
||||
sglang_out = sglang_awq_dequantize(qweight, scales, qzeros)
|
||||
|
||||
@@ -56,8 +78,13 @@ def calculate_diff(qweight_row: int, qweight_col: int):
|
||||
print("❌ Implementations differ")
|
||||
|
||||
|
||||
# Shape grid for the benchmark. CI environment uses simplified parameters
# (one row size, one column size); otherwise the full set of shapes is used.
if IS_CI:
    qweight_row_range = [128]  # Single row size for CI
    qweight_cols_range = [16]  # Single column size for CI
else:
    qweight_row_range = [3584, 18944, 128, 256, 512, 1024]
    qweight_cols_range = [448, 576, 4736, 16, 32, 64, 128]

# Every (row, col) combination becomes one benchmark configuration.
configs = list(itertools.product(qweight_row_range, qweight_cols_range))
|
||||
|
||||
@@ -67,9 +94,9 @@ configs = list(itertools.product(qweight_row_range, qweight_cols_range))
|
||||
x_names=["qweight_row", "qweight_col"],
|
||||
x_vals=configs,
|
||||
line_arg="provider",
|
||||
line_vals=["vllm", "sglang"],
|
||||
line_names=["VLLM", "SGL Kernel"],
|
||||
styles=[("blue", "-"), ("green", "-")],
|
||||
line_vals=["vllm", "sglang"] if VLLM_AVAILABLE else ["sglang"],
|
||||
line_names=["VLLM", "SGL Kernel"] if VLLM_AVAILABLE else ["SGL Kernel"],
|
||||
styles=[("blue", "-"), ("green", "-")] if VLLM_AVAILABLE else [("green", "-")],
|
||||
ylabel="us",
|
||||
plot_name="awq-dequantize-performance",
|
||||
args={},
|
||||
@@ -100,6 +127,8 @@ def benchmark(qweight_row, qweight_col, provider):
|
||||
quantiles = [0.5, 0.2, 0.8]
|
||||
|
||||
if provider == "vllm":
|
||||
if not VLLM_AVAILABLE:
|
||||
return (0, 0, 0)
|
||||
fn = lambda: vllm_awq_dequantize(
|
||||
qweight.clone(), scales.clone(), qzeros.clone()
|
||||
)
|
||||
@@ -114,5 +143,11 @@ def benchmark(qweight_row, qweight_col, provider):
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Simplify for CI environment: use smaller values there, and the larger
    # default shape everywhere else.
    qweight_row, qweight_col = (128, 16) if IS_CI else (3584, 448)

    # Run the implementation comparison once, at the selected shape only,
    # then run the benchmark and print its results.
    calculate_diff(qweight_row=qweight_row, qweight_col=qweight_col)
    benchmark.run(print_data=True)
|
||||
|
||||
Reference in New Issue
Block a user