Fix sgl-kernel benchmark dead code (#11022)
@@ -1,5 +1,6 @@
 import itertools
 import math
+import os
 from typing import Any, Dict, List, Optional, Tuple

 import numpy as np
@@ -7,11 +8,26 @@ import torch
 import triton
 import triton.testing
 from sgl_kernel import sgl_per_tensor_quant_fp8
-from vllm import _custom_ops as ops

+# Optional imports
+try:
+    from vllm import _custom_ops as ops
+
+    VLLM_AVAILABLE = True
+except ImportError:
+    ops = None
+    VLLM_AVAILABLE = False
+
 from sglang.srt.utils import is_hip

 _is_hip = is_hip()

+# CI environment detection
+IS_CI = (
+    os.getenv("CI", "false").lower() == "true"
+    or os.getenv("GITHUB_ACTIONS", "false").lower() == "true"
+)
+
 fp8_type_ = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn

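The guarded import and module-level flag above are a standard optional-dependency pattern: the import is attempted once at module load, and every call site branches on the flag instead of re-trying it. A minimal self-contained sketch of the same idea (the module and function names below are illustrative, not from this benchmark):

    # Optional-dependency guard: try the import once, record the outcome.
    try:
        import fancy_backend  # hypothetical optional package

        BACKEND_AVAILABLE = True
    except ImportError:
        fancy_backend = None
        BACKEND_AVAILABLE = False


    def double(x: float) -> float:
        """Use the optional backend when present, else a pure-Python fallback."""
        if BACKEND_AVAILABLE:
            return fancy_backend.double(x)  # hypothetical API
        return x * 2.0  # fallback keeps the script runnable without the package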
@@ -19,6 +35,9 @@ def vllm_scaled_fp8_quant(
     input: torch.Tensor,
     scale: Optional[torch.Tensor] = None,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
+    if not VLLM_AVAILABLE:
+        # Fallback to SGLang implementation
+        return sglang_scaled_fp8_quant(input, scale)
     return ops.scaled_fp8_quant(input, scale)


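Both quantization paths return a quantized tensor plus a per-tensor scale, per the Tuple[torch.Tensor, torch.Tensor] signature above. As a rough reference for what dynamic per-tensor FP8 quantization computes (a sketch only; the actual sgl_kernel and vLLM kernels may differ in rounding, clamping, and scale handling):

    import torch

    def reference_per_tensor_fp8(x: torch.Tensor, fp8_dtype=torch.float8_e4m3fn):
        # Dynamic per-tensor scale: map the largest |x| onto the FP8 max value.
        fp8_max = torch.finfo(fp8_dtype).max
        scale = x.abs().max().float() / fp8_max
        # Scale down, clamp into the representable range, cast to FP8.
        q = (x.float() / scale).clamp(-fp8_max, fp8_max).to(fp8_dtype)
        return q, scale

On HIP the target dtype would be torch.float8_e4m3fnuz instead, matching the fp8_type_ selection above.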
@@ -42,6 +61,10 @@ def calculate_diff(batch_size: int, seq_len: int):
     device = torch.device("cuda")
     x = torch.rand((batch_size, seq_len), dtype=torch.float16, device=device)

+    if not VLLM_AVAILABLE:
+        print("⚠️ vLLM not available, skipping comparison")
+        return
+
     vllm_out, vllm_scale = vllm_scaled_fp8_quant(x)
     sglang_out, sglang_scale = sglang_scaled_fp8_quant(x)

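The comparison body between this hunk and the next falls outside the diff context. One plausible shape for it, dequantizing both outputs and comparing in float32 (an assumption for illustration, not the file's actual code or tolerances):

    # Hypothetical comparison step between the two hunks:
    vllm_deq = vllm_out.float() * vllm_scale
    sglang_deq = sglang_out.float() * sglang_scale
    if torch.allclose(vllm_deq, sglang_deq, rtol=1e-3, atol=1e-5):
        print("✅ Implementations match")
    else:
        print("❌ Implementations differ")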
@@ -56,8 +79,13 @@ def calculate_diff(batch_size: int, seq_len: int):
         print("❌ Implementations differ")


-batch_size_range = [16, 32, 64, 128]
-seq_len_range = [64, 128, 256, 512, 1024, 2048]
+# CI environment uses simplified parameters
+if IS_CI:
+    batch_size_range = [16]  # Single batch size for CI
+    seq_len_range = [64]  # Single sequence length for CI
+else:
+    batch_size_range = [16, 32, 64, 128]
+    seq_len_range = [64, 128, 256, 512, 1024, 2048]

 configs = list(itertools.product(batch_size_range, seq_len_range))

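configs is the usual input to a triton.testing sweep; the benchmark body itself is outside this diff's context. A sketch of how such a (batch_size, seq_len) grid is commonly consumed via triton.testing.perf_report (the provider names, styling, and plot name are assumptions, and the module-level imports shown above are assumed in scope):

    @triton.testing.perf_report(
        triton.testing.Benchmark(
            x_names=["batch_size", "seq_len"],
            x_vals=configs,  # the itertools.product grid built above
            line_arg="provider",
            line_vals=["sglang"],
            line_names=["SGLang"],
            styles=[("blue", "-")],
            ylabel="ms",
            plot_name="per-tensor-fp8-quant",  # hypothetical plot name
            args={},
        )
    )
    def benchmark(batch_size, seq_len, provider):
        x = torch.rand((batch_size, seq_len), dtype=torch.float16, device="cuda")
        # do_bench runs warmup iterations, then reports the median runtime in ms.
        ms = triton.testing.do_bench(lambda: sglang_scaled_fp8_quant(x))
        return ms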