From b188a89a5d09ba634c77c34a2407e95dea5826b8 Mon Sep 17 00:00:00 2001 From: YanbingJiang Date: Wed, 16 Jul 2025 17:12:23 +0800 Subject: [PATCH] Fix CI xeon test with triton 3.3.1 (#8086) --- python/sglang/srt/layers/quantization/fp8_kernel.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/layers/quantization/fp8_kernel.py b/python/sglang/srt/layers/quantization/fp8_kernel.py index 7d73c5bc2..79504265c 100644 --- a/python/sglang/srt/layers/quantization/fp8_kernel.py +++ b/python/sglang/srt/layers/quantization/fp8_kernel.py @@ -29,6 +29,7 @@ from sglang.srt.utils import ( direct_register_custom_op, get_device_core_count, get_device_name, + is_cpu, is_cuda, is_hip, log_info_on_rank0, @@ -37,6 +38,7 @@ from sglang.srt.utils import ( _is_hip = is_hip() _is_cuda = is_cuda() +_is_cpu = is_cpu() if _is_cuda: from sgl_kernel import ( @@ -1168,7 +1170,7 @@ def scaled_fp8_quant( return output, scale -@triton.autotune( +fp8_autotune = triton.autotune( configs=[ triton.Config({"BLOCK_M": block_m}, num_warps=num_warps) for block_m in [16, 32, 64, 128] @@ -1176,6 +1178,8 @@ def scaled_fp8_quant( ], key=["K", "BLOCK_K", "M_ALIGNMENT"], ) + + @triton.jit def _per_token_group_quant_fp8_hopper_moe_mn_major( a, # (M, K):(K, 1) @@ -1221,6 +1225,12 @@ def _per_token_group_quant_fp8_hopper_moe_mn_major( tl.store(sfa_ptrs, inp_amax / 448.0, mask=coord_m < m) +if not _is_cpu: + _per_token_group_quant_fp8_hopper_moe_mn_major = fp8_autotune( + _per_token_group_quant_fp8_hopper_moe_mn_major + ) + + def per_token_group_quant_fp8_hopper_moe_mn_major( A: torch.Tensor, expert_offsets: torch.Tensor,