From a167fd0bcb9ef4b0f4331a109e40c8cdc770b026 Mon Sep 17 00:00:00 2001
From: Xiaoyu Zhang <35585791+BBuf@users.noreply.github.com>
Date: Thu, 24 Jul 2025 14:38:30 +0800
Subject: [PATCH] [code style] Clean up dead Triton kernel code in fused_moe
 and remove unused vllm_ops imports (#8310)

---
 .../layers/moe/fused_moe_triton/fused_moe.py  | 249 ++----------------
 .../compressed_tensors_moe.py                 |  11 +-
 .../sglang/srt/layers/quantization/utils.py   |   9 -
 3 files changed, 27 insertions(+), 242 deletions(-)

diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py b/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py
index 9c13c7e9d..267b594c0 100644
--- a/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py
+++ b/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py
@@ -53,9 +53,7 @@ elif _is_hip:
         from aiter import moe_sum
     except ImportError:
         raise ImportError("aiter is required when SGLANG_USE_AITER is set to True")
-else:
-    from vllm import _custom_ops as vllm_ops
-    from vllm._custom_ops import scaled_fp8_quant
+

 if _is_cuda or _is_hip:
     from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size
@@ -63,9 +61,6 @@ if _is_cuda or _is_hip:
 logger = logging.getLogger(__name__)

 padding_size = 128 if bool(int(os.getenv("SGLANG_MOE_PADDING", "0"))) else 0
-enable_moe_align_block_size_triton = bool(
-    int(os.getenv("ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON", "0"))
-)


 @triton.jit
@@ -533,190 +528,6 @@ def fused_moe_kernel(
     tl.store(c_ptrs, accumulator, mask=c_mask)


-@triton.jit
-def moe_align_block_size_stage1(
-    topk_ids_ptr,
-    tokens_cnts_ptr,
-    num_experts: tl.constexpr,
-    numel: tl.constexpr,
-    tokens_per_thread: tl.constexpr,
-):
-    pid = tl.program_id(0)
-
-    start_idx = pid * tokens_per_thread
-
-    off_c = (pid + 1) * num_experts
-
-    for i in range(tokens_per_thread):
-        if start_idx + i < numel:
-            idx = tl.load(topk_ids_ptr + start_idx + i)
-            token_cnt = tl.load(tokens_cnts_ptr + off_c + idx)
-            tl.store(tokens_cnts_ptr + off_c + idx, token_cnt + 1)
-
-
-@triton.jit
-def moe_align_block_size_stage2(
-    tokens_cnts_ptr,
-    num_experts: tl.constexpr,
-):
-    pid = tl.program_id(0)
-
-    last_cnt = 0
-    for i in range(1, num_experts + 1):
-        token_cnt = tl.load(tokens_cnts_ptr + i * num_experts + pid)
-        last_cnt = last_cnt + token_cnt
-        tl.store(tokens_cnts_ptr + i * num_experts + pid, last_cnt)
-
-
-@triton.jit
-def moe_align_block_size_stage3(
-    total_tokens_post_pad_ptr,
-    tokens_cnts_ptr,
-    cumsum_ptr,
-    num_experts: tl.constexpr,
-    block_size: tl.constexpr,
-):
-    last_cumsum = 0
-    off_cnt = num_experts * num_experts
-    for i in range(1, num_experts + 1):
-        token_cnt = tl.load(tokens_cnts_ptr + off_cnt + i - 1)
-        last_cumsum = last_cumsum + tl.cdiv(token_cnt, block_size) * block_size
-        tl.store(cumsum_ptr + i, last_cumsum)
-    tl.store(total_tokens_post_pad_ptr, last_cumsum)
-
-
-@triton.jit
-def moe_align_block_size_stage4(
-    topk_ids_ptr,
-    sorted_token_ids_ptr,
-    expert_ids_ptr,
-    tokens_cnts_ptr,
-    cumsum_ptr,
-    num_experts: tl.constexpr,
-    block_size: tl.constexpr,
-    numel: tl.constexpr,
-    tokens_per_thread: tl.constexpr,
-):
-    pid = tl.program_id(0)
-    start_idx = tl.load(cumsum_ptr + pid)
-    end_idx = tl.load(cumsum_ptr + pid + 1)
-
-    for i in range(start_idx, end_idx, block_size):
-        tl.store(expert_ids_ptr + i // block_size, pid)
-
-    start_idx = pid * tokens_per_thread
-    off_t = pid * num_experts
-
-    for i in range(start_idx, tl.minimum(start_idx + tokens_per_thread, numel)):
-        expert_id = tl.load(topk_ids_ptr + i)
-        token_cnt = tl.load(tokens_cnts_ptr + off_t + expert_id)
-        rank_post_pad = token_cnt + tl.load(cumsum_ptr + expert_id)
-        tl.store(sorted_token_ids_ptr + rank_post_pad, i)
-        tl.store(tokens_cnts_ptr + off_t + expert_id, token_cnt + 1)
-
-
-def moe_align_block_size_triton(
-    topk_ids: torch.Tensor,
-    num_experts: int,
-    block_size: int,
-    sorted_token_ids: torch.Tensor,
-    expert_ids: torch.Tensor,
-    num_tokens_post_pad: torch.Tensor,
-) -> None:
-    numel = topk_ids.numel()
-    grid = (num_experts,)
-    tokens_cnts = torch.zeros(
-        (num_experts + 1, num_experts), dtype=torch.int32, device=topk_ids.device
-    )
-    cumsum = torch.zeros((num_experts + 1,), dtype=torch.int32, device=topk_ids.device)
-    tokens_per_thread = ceil_div(numel, num_experts)
-
-    moe_align_block_size_stage1[grid](
-        topk_ids,
-        tokens_cnts,
-        num_experts,
-        numel,
-        tokens_per_thread,
-    )
-    moe_align_block_size_stage2[grid](
-        tokens_cnts,
-        num_experts,
-    )
-    moe_align_block_size_stage3[(1,)](
-        num_tokens_post_pad,
-        tokens_cnts,
-        cumsum,
-        num_experts,
-        block_size,
-    )
-    moe_align_block_size_stage4[grid](
-        topk_ids,
-        sorted_token_ids,
-        expert_ids,
-        tokens_cnts,
-        cumsum,
-        num_experts,
-        block_size,
-        numel,
-        tokens_per_thread,
-    )
-
-
-@triton.jit
-def init_sorted_ids_and_cumsum_buffer_kernel(
-    sorted_ids_ptr,
-    cumsum_buffer_ptr,
-    max_num_tokens_padded,
-    topk_ids_numel,
-    num_experts: tl.constexpr,
-    BLOCK_SIZE: tl.constexpr,
-    ALIGNED_NUM_EXPERTS_P1: tl.constexpr,
-):
-    pid = tl.program_id(0)
-    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
-
-    sorted_ids_blocks = tl.cdiv(max_num_tokens_padded, BLOCK_SIZE)
-
-    if pid < sorted_ids_blocks:
-        mask = offsets < max_num_tokens_padded
-        tl.store(
-            sorted_ids_ptr + offsets,
-            tl.full((BLOCK_SIZE,), topk_ids_numel, dtype=tl.int32),
-            mask=mask,
-        )
-    elif pid == sorted_ids_blocks:
-        offset_e = tl.arange(0, ALIGNED_NUM_EXPERTS_P1)
-        mask_e = offset_e < num_experts + 1
-        tl.store(
-            cumsum_buffer_ptr + offset_e,
-            tl.zeros((ALIGNED_NUM_EXPERTS_P1,), dtype=tl.int32),
-            mask=mask_e,
-        )
-
-
-def init_sorted_ids_and_cumsum_buffer(
-    max_num_tokens_padded: int, topk_ids_numel: int, num_experts: int, device="cuda"
-):
-    sorted_ids = torch.empty((max_num_tokens_padded,), dtype=torch.int32, device=device)
-    cumsum_buffer = torch.empty((num_experts + 1,), dtype=torch.int32, device=device)
-
-    BLOCK_SIZE = 1024
-    sorted_ids_blocks = triton.cdiv(max_num_tokens_padded, BLOCK_SIZE)
-    grid = (sorted_ids_blocks + 1,)
-
-    init_sorted_ids_and_cumsum_buffer_kernel[grid](
-        sorted_ids,
-        cumsum_buffer,
-        max_num_tokens_padded,
-        topk_ids_numel,
-        num_experts,
-        BLOCK_SIZE,
-        next_power_of_2(num_experts + 1),
-    )
-
-    return sorted_ids, cumsum_buffer
-
-
 def moe_align_block_size(
     topk_ids: torch.Tensor, block_size: int, num_experts: int
 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
@@ -766,42 +577,32 @@ def moe_align_block_size(
         (max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
     )
     num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device)
-    if enable_moe_align_block_size_triton:
+
+    cumsum_buffer = torch.empty(
+        (num_experts + 1,), dtype=torch.int32, device=topk_ids.device
+    )
+    token_cnts_buffer = torch.empty(
+        (num_experts + 1) * num_experts,
+        dtype=torch.int32,
+        device=topk_ids.device,
+    )
+
+    # Threshold based on benchmark results
+    fuse_sorted_ids_padding = sorted_ids.shape[0] <= 4096
+    if not fuse_sorted_ids_padding:
         sorted_ids.fill_(topk_ids.numel())
-        moe_align_block_size_triton(
-            topk_ids,
-            num_experts,
-            block_size,
-            sorted_ids,
-            expert_ids,
-            num_tokens_post_pad,
-        )
-    else:
-        cumsum_buffer = torch.empty(
-            (num_experts + 1,), dtype=torch.int32, device=topk_ids.device
-        )
-        token_cnts_buffer = torch.empty(
-            (num_experts + 1) * num_experts,
-            dtype=torch.int32,
-            device=topk_ids.device,
-        )
-
-        # Threshold based on benchmark results
-        fuse_sorted_ids_padding = sorted_ids.shape[0] <= 4096
-        if not fuse_sorted_ids_padding:
-            sorted_ids.fill_(topk_ids.numel())
-
-        sgl_moe_align_block_size(
-            topk_ids,
-            num_experts,
-            block_size,
-            sorted_ids,
-            expert_ids,
-            num_tokens_post_pad,
-            token_cnts_buffer,
-            cumsum_buffer,
-            fuse_sorted_ids_padding,
-        )
+
+    sgl_moe_align_block_size(
+        topk_ids,
+        num_experts,
+        block_size,
+        sorted_ids,
+        expert_ids,
+        num_tokens_post_pad,
+        token_cnts_buffer,
+        cumsum_buffer,
+        fuse_sorted_ids_padding,
+    )
     return sorted_ids, expert_ids, num_tokens_post_pad
diff --git a/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py
index af1f6cbf7..525a75069 100644
--- a/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py
+++ b/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py
@@ -28,15 +28,6 @@ if TYPE_CHECKING:
         CompressedTensorsConfig,
     )

-_is_cuda = is_cuda()
-_is_npu = is_npu()
-_is_cpu_amx_available = cpu_has_amx_support()
-_is_cpu = is_cpu()
-_is_hip = is_hip()
-
-if not (_is_cuda or _is_npu or (_is_cpu and _is_cpu_amx_available) or _is_hip):
-    from vllm import _custom_ops as vllm_ops
-    from vllm._custom_ops import scaled_fp8_quant

 try:
     import vllm
@@ -568,6 +559,8 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
             requires_grad=False,
         )

+        from vllm import _custom_ops as vllm_ops
+
         marlin_w13_qweight = vllm_ops.gptq_marlin_moe_repack(
             layer.w13_weight_packed,
             layer.w13_g_idx_sort_indices,
diff --git a/python/sglang/srt/layers/quantization/utils.py b/python/sglang/srt/layers/quantization/utils.py
index 8904247a6..9b19e0309 100644
--- a/python/sglang/srt/layers/quantization/utils.py
+++ b/python/sglang/srt/layers/quantization/utils.py
@@ -17,15 +17,6 @@ from sglang.srt.utils import cpu_has_amx_support, is_cpu, is_cuda, is_hip, is_np
 if TYPE_CHECKING:
     from sglang.srt.layers.quantization.base_config import QuantizationConfig

-_is_cuda = is_cuda()
-_is_npu = is_npu()
-_is_cpu_amx_available = cpu_has_amx_support()
-_is_cpu = is_cpu()
-_is_hip = is_hip()
-
-if not (_is_cuda or _is_npu or (_is_cpu and _is_cpu_amx_available) or _is_hip):
-    from vllm._custom_ops import scaled_fp8_quant
-
 def is_layer_skipped(
     prefix: str,
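
For anyone auditing the cleanup, the sketch below is a hypothetical pure-PyTorch reference for the contract the removed Triton stages implemented and the surviving sgl_moe_align_block_size path is expected to satisfy: token indices grouped by expert, each expert's run padded up to a multiple of block_size with topk_ids.numel() as the sentinel fill value, plus a per-block expert id and the padded total. The helper name moe_align_block_size_ref is invented for illustration, and the intra-expert ordering of token ids is not guaranteed to match the CUDA kernel.

import torch


def moe_align_block_size_ref(topk_ids: torch.Tensor, block_size: int, num_experts: int):
    # Reference semantics only; assumes topk_ids holds expert indices in
    # [0, num_experts) with shape (num_tokens, top_k).
    numel = topk_ids.numel()
    flat = topk_ids.flatten().to(torch.int64)

    # Pad each expert's token count up to a multiple of block_size.
    counts = torch.bincount(flat, minlength=num_experts)
    padded = (counts + block_size - 1) // block_size * block_size
    cumsum = torch.zeros(num_experts + 1, dtype=torch.int64)
    cumsum[1:] = torch.cumsum(padded, dim=0)
    total_padded = int(cumsum[-1])

    # Slots left empty by padding hold numel, an out-of-range sentinel
    # that downstream kernels can mask out.
    sorted_token_ids = torch.full((total_padded,), numel, dtype=torch.int32)
    expert_ids = torch.empty(total_padded // block_size, dtype=torch.int32)
    fill = cumsum[:-1].clone()
    for i in range(numel):
        e = int(flat[i])
        sorted_token_ids[int(fill[e])] = i
        fill[e] += 1
    for e in range(num_experts):
        expert_ids[int(cumsum[e]) // block_size : int(cumsum[e + 1]) // block_size] = e

    num_tokens_post_pad = torch.tensor([total_padded], dtype=torch.int32)
    return sorted_token_ids, expert_ids, num_tokens_post_pad

On the real code path these three outputs are written by the kernel into the preallocated sorted_ids, expert_ids, and num_tokens_post_pad buffers (with fuse_sorted_ids_padding deciding whether the sentinel fill happens inside the kernel or via sorted_ids.fill_ beforehand); the reference allocates fresh tensors purely so results can be compared.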