Fix CI xeon test with triton 3.3.1 (#8086)

This commit is contained in:
YanbingJiang
2025-07-16 17:12:23 +08:00
committed by GitHub
parent 497efe747d
commit b188a89a5d

View File

@@ -29,6 +29,7 @@ from sglang.srt.utils import (
direct_register_custom_op,
get_device_core_count,
get_device_name,
is_cpu,
is_cuda,
is_hip,
log_info_on_rank0,
@@ -37,6 +38,7 @@ from sglang.srt.utils import (
_is_hip = is_hip()
_is_cuda = is_cuda()
_is_cpu = is_cpu()
if _is_cuda:
from sgl_kernel import (
@@ -1168,7 +1170,7 @@ def scaled_fp8_quant(
return output, scale
@triton.autotune(
fp8_autotune = triton.autotune(
configs=[
triton.Config({"BLOCK_M": block_m}, num_warps=num_warps)
for block_m in [16, 32, 64, 128]
@@ -1176,6 +1178,8 @@ def scaled_fp8_quant(
],
key=["K", "BLOCK_K", "M_ALIGNMENT"],
)
@triton.jit
def _per_token_group_quant_fp8_hopper_moe_mn_major(
a, # (M, K):(K, 1)
@@ -1221,6 +1225,12 @@ def _per_token_group_quant_fp8_hopper_moe_mn_major(
tl.store(sfa_ptrs, inp_amax / 448.0, mask=coord_m < m)
if not _is_cpu:
_per_token_group_quant_fp8_hopper_moe_mn_major = fp8_autotune(
_per_token_group_quant_fp8_hopper_moe_mn_major
)
def per_token_group_quant_fp8_hopper_moe_mn_major(
A: torch.Tensor,
expert_offsets: torch.Tensor,