Fix CI xeon test with triton 3.3.1 (#8086)
This commit is contained in:
@@ -29,6 +29,7 @@ from sglang.srt.utils import (
|
|||||||
direct_register_custom_op,
|
direct_register_custom_op,
|
||||||
get_device_core_count,
|
get_device_core_count,
|
||||||
get_device_name,
|
get_device_name,
|
||||||
|
is_cpu,
|
||||||
is_cuda,
|
is_cuda,
|
||||||
is_hip,
|
is_hip,
|
||||||
log_info_on_rank0,
|
log_info_on_rank0,
|
||||||
@@ -37,6 +38,7 @@ from sglang.srt.utils import (
|
|||||||
|
|
||||||
_is_hip = is_hip()
|
_is_hip = is_hip()
|
||||||
_is_cuda = is_cuda()
|
_is_cuda = is_cuda()
|
||||||
|
_is_cpu = is_cpu()
|
||||||
|
|
||||||
if _is_cuda:
|
if _is_cuda:
|
||||||
from sgl_kernel import (
|
from sgl_kernel import (
|
||||||
@@ -1168,7 +1170,7 @@ def scaled_fp8_quant(
|
|||||||
return output, scale
|
return output, scale
|
||||||
|
|
||||||
|
|
||||||
@triton.autotune(
|
fp8_autotune = triton.autotune(
|
||||||
configs=[
|
configs=[
|
||||||
triton.Config({"BLOCK_M": block_m}, num_warps=num_warps)
|
triton.Config({"BLOCK_M": block_m}, num_warps=num_warps)
|
||||||
for block_m in [16, 32, 64, 128]
|
for block_m in [16, 32, 64, 128]
|
||||||
@@ -1176,6 +1178,8 @@ def scaled_fp8_quant(
|
|||||||
],
|
],
|
||||||
key=["K", "BLOCK_K", "M_ALIGNMENT"],
|
key=["K", "BLOCK_K", "M_ALIGNMENT"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@triton.jit
|
@triton.jit
|
||||||
def _per_token_group_quant_fp8_hopper_moe_mn_major(
|
def _per_token_group_quant_fp8_hopper_moe_mn_major(
|
||||||
a, # (M, K):(K, 1)
|
a, # (M, K):(K, 1)
|
||||||
@@ -1221,6 +1225,12 @@ def _per_token_group_quant_fp8_hopper_moe_mn_major(
|
|||||||
tl.store(sfa_ptrs, inp_amax / 448.0, mask=coord_m < m)
|
tl.store(sfa_ptrs, inp_amax / 448.0, mask=coord_m < m)
|
||||||
|
|
||||||
|
|
||||||
|
if not _is_cpu:
|
||||||
|
_per_token_group_quant_fp8_hopper_moe_mn_major = fp8_autotune(
|
||||||
|
_per_token_group_quant_fp8_hopper_moe_mn_major
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def per_token_group_quant_fp8_hopper_moe_mn_major(
|
def per_token_group_quant_fp8_hopper_moe_mn_major(
|
||||||
A: torch.Tensor,
|
A: torch.Tensor,
|
||||||
expert_offsets: torch.Tensor,
|
expert_offsets: torch.Tensor,
|
||||||
|
|||||||
Reference in New Issue
Block a user