Fix sgl-kernel benchmark dead code (#11022)
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
import itertools
|
||||
import os
|
||||
import time
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
@@ -16,15 +17,28 @@ from sglang.srt.layers.quantization.fp8_kernel import (
|
||||
from sglang.srt.layers.quantization.fp8_kernel import sglang_per_token_group_quant_8bit
|
||||
from sglang.srt.utils import is_hip
|
||||
|
||||
# CI environment detection
|
||||
IS_CI = (
|
||||
os.getenv("CI", "false").lower() == "true"
|
||||
or os.getenv("GITHUB_ACTIONS", "false").lower() == "true"
|
||||
)
|
||||
|
||||
_is_hip = is_hip()
|
||||
fp8_type_ = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
|
||||
|
||||
|
||||
num_tokens_range = [1, 4, 16, 64, 256, 768, 2048, 8192, 16384]
|
||||
hidden_dim_range = [1536, 7168, 18432] # For DeepSeek V3/R1
|
||||
group_size_range = [128] # For DeepSeek V3/R1
|
||||
# TODO test int8
|
||||
dst_dtype_range = [fp8_type_]
|
||||
# CI environment uses simplified parameters
|
||||
if IS_CI:
|
||||
num_tokens_range = [64] # Single value for CI
|
||||
hidden_dim_range = [1536] # Single value for CI
|
||||
group_size_range = [128] # Keep as is
|
||||
dst_dtype_range = [fp8_type_] # Keep as is
|
||||
else:
|
||||
num_tokens_range = [1, 4, 16, 64, 256, 768, 2048, 8192, 16384]
|
||||
hidden_dim_range = [1536, 7168, 18432] # For DeepSeek V3/R1
|
||||
group_size_range = [128] # For DeepSeek V3/R1
|
||||
# TODO test int8
|
||||
dst_dtype_range = [fp8_type_]
|
||||
flags_range = [
|
||||
dict(
|
||||
column_major_scales=False,
|
||||
@@ -82,7 +96,7 @@ def benchmark(num_tokens, hidden_dim, group_size, dst_dtype, flags, provider):
|
||||
x = torch.randn(num_tokens, hidden_dim, device=device, dtype=torch.bfloat16)
|
||||
|
||||
fn, kernel_names = {
|
||||
"triton": (triton_per_token_group_quant_8bit, "_per_token_group_quant_fp8"),
|
||||
"triton": (triton_per_token_group_quant_8bit, "_per_token_group_quant_8bit"),
|
||||
"sglang": (
|
||||
sglang_per_token_group_quant_8bit,
|
||||
"per_token_group_quant_8bit_kernel",
|
||||
|
||||
Reference in New Issue
Block a user