Fix CI xeon test with triton 3.3.1 (#8086)
This commit is contained in:
@@ -29,6 +29,7 @@ from sglang.srt.utils import (
|
||||
direct_register_custom_op,
|
||||
get_device_core_count,
|
||||
get_device_name,
|
||||
is_cpu,
|
||||
is_cuda,
|
||||
is_hip,
|
||||
log_info_on_rank0,
|
||||
@@ -37,6 +38,7 @@ from sglang.srt.utils import (
|
||||
|
||||
_is_hip = is_hip()
|
||||
_is_cuda = is_cuda()
|
||||
_is_cpu = is_cpu()
|
||||
|
||||
if _is_cuda:
|
||||
from sgl_kernel import (
|
||||
@@ -1168,7 +1170,7 @@ def scaled_fp8_quant(
|
||||
return output, scale
|
||||
|
||||
|
||||
@triton.autotune(
|
||||
fp8_autotune = triton.autotune(
|
||||
configs=[
|
||||
triton.Config({"BLOCK_M": block_m}, num_warps=num_warps)
|
||||
for block_m in [16, 32, 64, 128]
|
||||
@@ -1176,6 +1178,8 @@ def scaled_fp8_quant(
|
||||
],
|
||||
key=["K", "BLOCK_K", "M_ALIGNMENT"],
|
||||
)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _per_token_group_quant_fp8_hopper_moe_mn_major(
|
||||
a, # (M, K):(K, 1)
|
||||
@@ -1221,6 +1225,12 @@ def _per_token_group_quant_fp8_hopper_moe_mn_major(
|
||||
tl.store(sfa_ptrs, inp_amax / 448.0, mask=coord_m < m)
|
||||
|
||||
|
||||
if not _is_cpu:
|
||||
_per_token_group_quant_fp8_hopper_moe_mn_major = fp8_autotune(
|
||||
_per_token_group_quant_fp8_hopper_moe_mn_major
|
||||
)
|
||||
|
||||
|
||||
def per_token_group_quant_fp8_hopper_moe_mn_major(
|
||||
A: torch.Tensor,
|
||||
expert_offsets: torch.Tensor,
|
||||
|
||||
Reference in New Issue
Block a user