diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/__init__.py b/python/sglang/srt/layers/moe/fused_moe_triton/__init__.py index 6d8aee852..be3ed3af4 100644 --- a/python/sglang/srt/layers/moe/fused_moe_triton/__init__.py +++ b/python/sglang/srt/layers/moe/fused_moe_triton/__init__.py @@ -1,16 +1,18 @@ from contextlib import contextmanager from typing import Any, Dict, Optional -from sglang.srt.layers.moe.fused_moe_triton.fused_moe import ( - fused_experts, +from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts +from sglang.srt.layers.moe.fused_moe_triton.fused_moe_triton_config import ( get_config_file_name, - moe_align_block_size, try_get_optimal_moe_config, ) from sglang.srt.layers.moe.fused_moe_triton.layer import ( FusedMoE, FusedMoeWeightScaleSupported, ) +from sglang.srt.layers.moe.fused_moe_triton.moe_align_block_size import ( + moe_align_block_size, +) _config: Optional[Dict[str, Any]] = None diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py b/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py index c961dd554..4660df676 100644 --- a/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +++ b/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py @@ -5,39 +5,27 @@ from __future__ import annotations import functools -import json -import logging import os -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple +from typing import List, Optional import torch -import triton import triton.language as tl from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig from sglang.srt.layers.moe.topk import StandardTopKOutput -from sglang.srt.layers.quantization.fp8_kernel import ( - per_token_group_quant_fp8, - scaled_fp8_quant, - sglang_per_token_group_quant_fp8, -) -from sglang.srt.layers.quantization.int8_kernel import ( - per_token_group_quant_int8, - per_token_quant_int8, - sglang_per_token_group_quant_int8, -) from sglang.srt.utils import ( - ceil_div, cpu_has_amx_support, direct_register_custom_op, get_bool_env_var, - get_device_name, is_cpu, is_cuda, is_hip, - next_power_of_2, ) +from .fused_moe_triton_config import get_config_dtype_str, try_get_optimal_moe_config +from .fused_moe_triton_kernels import invoke_fused_moe_kernel, moe_sum_reduce_triton +from .moe_align_block_size import moe_align_block_size + _is_hip = is_hip() _is_cuda = is_cuda() _is_cpu_amx_available = cpu_has_amx_support() @@ -59,954 +47,9 @@ elif _is_hip: else: from vllm import _custom_ops as vllm_ops - -if _is_cuda or _is_hip: - from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size - - -logger = logging.getLogger(__name__) padding_size = 128 if bool(int(os.getenv("SGLANG_MOE_PADDING", "0"))) else 0 -@triton.jit -def write_zeros_to_output( - c_ptr, - stride_cm, - stride_cn, - pid_n, - N, - offs_token, - token_mask, - BLOCK_SIZE_M, - BLOCK_SIZE_N, - compute_type, -): - accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=compute_type) - offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :] - c_mask = token_mask[:, None] & (offs_cn[None, :] < N) - tl.store(c_ptrs, accumulator, mask=c_mask) - - -@triton.jit -def fused_moe_kernel_gptq_awq( - # Pointers to matrices - a_ptr, - b_ptr, - c_ptr, - b_scale_ptr, - b_zp_ptr, - topk_weights_ptr, - sorted_token_ids_ptr, - expert_ids_ptr, - num_tokens_post_padded_ptr, - # Matrix dimensions - N: tl.constexpr, - K: tl.constexpr, - EM, - num_valid_tokens, - # The stride variables represent how 
much to increase the ptr by when - # moving by 1 element in a particular dimension. E.g. `stride_am` is - # how much to increase `a_ptr` by to get the element one row down - # (A has M rows). - stride_am, - stride_ak, - stride_be, - stride_bk, - stride_bn, - stride_cm, - stride_cn, - stride_bse, - stride_bsk, - stride_bsn, - stride_bze, - stride_bzk, - stride_bzn, - group_size: tl.constexpr, - # Meta-parameters - BLOCK_SIZE_M: tl.constexpr, - BLOCK_SIZE_N: tl.constexpr, - BLOCK_SIZE_K: tl.constexpr, - GROUP_SIZE_M: tl.constexpr, - MUL_ROUTED_WEIGHT: tl.constexpr, - top_k: tl.constexpr, - compute_type: tl.constexpr, - has_zp: tl.constexpr, - use_int4_w4a16: tl.constexpr, - use_int8_w8a16: tl.constexpr, - even_Ks: tl.constexpr, -): - """ - Implements the fused computation for a Mixture of Experts (MOE) using - token and expert matrices. - Key Parameters: - - A: The input tensor representing tokens with shape (*, K), where '*' can - be any shape representing batches and K is the feature dimension of - each token. - - B: The stacked MOE weight tensor with shape (E, N, K), where E is - the number of experts, K is the input feature dimension, and N is - the output feature dimension. - - C: The output cache tensor with shape (M, topk, N), where M is the - total number of tokens post padding, topk is the number of times - each token is repeated, and N is the output feature dimension. - - sorted_token_ids: A tensor containing the sorted indices of tokens, - repeated topk times and arranged by the expert index they are - assigned to. - - expert_ids: A tensor containing the indices of the expert for each - block. It determines which expert matrix from B should be used for - each block in A. - This kernel performs the multiplication of a token by its corresponding - expert matrix as determined by `expert_ids`. The sorting of - `sorted_token_ids` by expert index and padding ensures divisibility by - BLOCK_SIZE_M, which is necessary to maintain consistency in block matrix - multiplication across different blocks processed by the same expert. - """ - # ----------------------------------------------------------- - # Map program ids `pid` to the block of C it should compute. - # This is done in a grouped ordering to promote L2 data reuse. - pid = tl.program_id(axis=0) - num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M) - num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) - num_pid_in_group = GROUP_SIZE_M * num_pid_n - group_id = pid // num_pid_in_group - first_pid_m = group_id * GROUP_SIZE_M - group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) - pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) - pid_n = (pid % num_pid_in_group) // group_size_m - - # ---------------------------------------------------------- - # Create pointers for the first blocks of A and B. 
- # We will advance this pointer as we move in the K direction - # and accumulate - # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers - # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers - num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr) - if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded: - return - offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64) - offs_token = tl.load(sorted_token_ids_ptr + offs_token_id) - token_mask = offs_token < num_valid_tokens - - off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64) - if off_experts == -1: - # ----------------------------------------------------------- - # Write back zeros to the output when the expert is not - # in the current expert parallel rank. - write_zeros_to_output( - c_ptr, - stride_cm, - stride_cn, - pid_n, - N, - offs_token, - token_mask, - BLOCK_SIZE_M, - BLOCK_SIZE_N, - compute_type, - ) - return - - offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N - offs_k = tl.arange(0, BLOCK_SIZE_K) - a_ptrs = a_ptr + ( - offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak - ) - - if use_int4_w4a16: - b_ptrs = ( - b_ptr - + off_experts * stride_be - + (offs_k[:, None] // 2) * stride_bk - + offs_bn[None, :] * stride_bn - ) - b_shifter = (offs_k[:, None] % 2) * 4 - elif use_int8_w8a16: - b_ptrs = ( - b_ptr - + off_experts * stride_be - + offs_k[:, None] * stride_bk - + offs_bn[None, :] * stride_bn - ) - - if not has_zp and use_int4_w4a16: - b_zp_num = 8 - if not has_zp and use_int8_w8a16: - b_zp_num = 128 - elif has_zp and use_int4_w4a16: - b_zp_shifter = (offs_bn[None, :] % 2) * 4 - - # ----------------------------------------------------------- - # Iterate to compute a block of the C matrix. - # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block - # of fp32 values for higher accuracy. - # `accumulator` will be converted back to fp16 after the loop. - accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): - # Load the next block of A and B, generate a mask by checking the - # K dimension. - - if not even_Ks: - k_mask = offs_k[:, None] < K - k * BLOCK_SIZE_K - k_other = 0.0 - else: - k_mask = None - k_other = None - - a = tl.load( - a_ptrs, - mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K), - other=0.0, - ) - b = tl.load(b_ptrs) - if use_int4_w4a16: - b = (b >> b_shifter) & 0xF - - b_scale_ptrs = ( - b_scale_ptr - + off_experts * stride_bse - + offs_bn[None, :] * stride_bsn - + ((offs_k[:, None] + BLOCK_SIZE_K * k) // group_size) * stride_bsk - ) - b_scale = tl.load(b_scale_ptrs, mask=k_mask, other=k_other) - b_scale = b_scale.to(tl.float32) - - if has_zp and use_int4_w4a16: - offs_k_true = (offs_k[:, None] + BLOCK_SIZE_K * k) // group_size - b_zp_ptrs = ( - b_zp_ptr - + off_experts * stride_bze - + (offs_bn[None, :] // 2) * stride_bzn - + offs_k_true * stride_bzk - ) - b_zp = tl.load(b_zp_ptrs, mask=k_mask, other=k_other) - b_zp = (b_zp >> b_zp_shifter) & 0xF - b_zp = b_zp.to(tl.float32) - elif has_zp and use_int8_w8a16: - offs_k_true = (offs_k[:, None] + BLOCK_SIZE_K * k) // group_size - b_zp_ptrs = ( - b_zp_ptr - + off_experts * stride_bze - + offs_bn[None, :] * stride_bzn - + offs_k_true * stride_bzk - ) - b_zp = tl.load(b_zp_ptrs, mask=k_mask, other=k_other) - b_zp = b_zp.to(tl.float32) - - # We accumulate along the K dimension. 
- if has_zp: - b = ((b.to(tl.float32) - b_zp) * b_scale).to(compute_type) - else: - b = ((b.to(tl.float32) - b_zp_num) * b_scale).to(compute_type) - accumulator = tl.dot(a, b, acc=accumulator) - - # Advance the ptrs to the next K block. - a_ptrs += BLOCK_SIZE_K * stride_ak - if use_int4_w4a16: - b_ptrs += (BLOCK_SIZE_K // 2) * stride_bk - else: - b_ptrs += BLOCK_SIZE_K * stride_bk - - if MUL_ROUTED_WEIGHT: - moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0) - accumulator = accumulator * moe_weight[:, None] - - accumulator = accumulator.to(compute_type) - # ----------------------------------------------------------- - # Write back the block of the output - offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :] - c_mask = token_mask[:, None] & (offs_cn[None, :] < N) - tl.store(c_ptrs, accumulator, mask=c_mask) - - -@triton.jit -def fused_moe_kernel( - # Pointers to matrices - a_ptr, - b_ptr, - bias_ptr, - c_ptr, - a_scale_ptr, - b_scale_ptr, - topk_weights_ptr, - sorted_token_ids_ptr, - expert_ids_ptr, - num_tokens_post_padded_ptr, - # Matrix dimensions - N, - K, - EM, - num_valid_tokens, - # The stride variables represent how much to increase the ptr by when - # moving by 1 element in a particular dimension. E.g. `stride_am` is - # how much to increase `a_ptr` by to get the element one row down - # (A has M rows). - stride_am, - stride_ak, - stride_be, - stride_bk, - stride_bn, - stride_bias_e, - stride_bias_n, - stride_cm, - stride_cn, - stride_asm, - stride_ask, - stride_bse, - stride_bsk, - stride_bsn, - # Block size for block-wise quantization - group_n: tl.constexpr, - group_k: tl.constexpr, - # Meta-parameters - BLOCK_SIZE_M: tl.constexpr, - BLOCK_SIZE_N: tl.constexpr, - BLOCK_SIZE_K: tl.constexpr, - GROUP_SIZE_M: tl.constexpr, - MUL_ROUTED_WEIGHT: tl.constexpr, - top_k: tl.constexpr, - compute_type: tl.constexpr, - use_fp8_w8a8: tl.constexpr, - use_int8_w8a8: tl.constexpr, - use_int8_w8a16: tl.constexpr, - per_channel_quant: tl.constexpr, - even_Ks: tl.constexpr, -): - """ - Implements the fused computation for a Mixture of Experts (MOE) using - token and expert matrices. - - Key Parameters: - - A: The input tensor representing tokens with shape (*, K), where '*' can - be any shape representing batches and K is the feature dimension of - each token. - - B: The stacked MOE weight tensor with shape (E, N, K), where E is - the number of experts, K is the input feature dimension, and N is - the output feature dimension. - - C: The output cache tensor with shape (M, topk, N), where M is the - total number of tokens post padding, topk is the number of times - each token is repeated, and N is the output feature dimension. - - sorted_token_ids: A tensor containing the sorted indices of tokens, - repeated topk times and arranged by the expert index they are - assigned to. - - expert_ids: A tensor containing the indices of the expert for each - block. It determines which expert matrix from B should be used for - each block in A. - - This kernel performs the multiplication of a token by its corresponding - expert matrix as determined by `expert_ids`. The sorting of - `sorted_token_ids` by expert index and padding ensures divisibility by - BLOCK_SIZE_M, which is necessary to maintain consistency in block matrix - multiplication across different blocks processed by the same expert. 
- """ - # ----------------------------------------------------------- - # Map program ids `pid` to the block of C it should compute. - # This is done in a grouped ordering to promote L2 data reuse. - pid = tl.program_id(axis=0) - num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M) - num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) - num_pid_in_group = GROUP_SIZE_M * num_pid_n - group_id = pid // num_pid_in_group - first_pid_m = group_id * GROUP_SIZE_M - group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) - pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) - pid_n = (pid % num_pid_in_group) // group_size_m - - # ---------------------------------------------------------- - # Create pointers for the first blocks of A and B. - # We will advance this pointer as we move in the K direction - # and accumulate - # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers - # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers - num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr) - if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded: - return - offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64) - offs_token = tl.load(sorted_token_ids_ptr + offs_token_id) - offs_token = offs_token.to(tl.int64) - token_mask = offs_token < num_valid_tokens - - off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64) - - if off_experts == -1: - # ----------------------------------------------------------- - # Write back zeros to the output when the expert is not - # in the current expert parallel rank. - write_zeros_to_output( - c_ptr, - stride_cm, - stride_cn, - pid_n, - N, - offs_token, - token_mask, - BLOCK_SIZE_M, - BLOCK_SIZE_N, - compute_type, - ) - return - - offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N - offs_k = tl.arange(0, BLOCK_SIZE_K) - a_ptrs = a_ptr + ( - offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak - ) - - b_ptrs = ( - b_ptr - + off_experts * stride_be - + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) - ) - if bias_ptr is not None: - bias = tl.load( - bias_ptr + off_experts * stride_bias_e + offs_bn[None, :] * stride_bias_n - ) - if use_int8_w8a16: - b_scale_ptrs = ( - b_scale_ptr + off_experts * stride_bse + offs_bn[None, :] * stride_bsn - ) - b_scale = tl.load(b_scale_ptrs) - - if use_fp8_w8a8 or use_int8_w8a8: - # block-wise - if group_k > 0 and group_n > 0: - a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm - offs_bsn = offs_bn // group_n - b_scale_ptrs = ( - b_scale_ptr + off_experts * stride_bse + offs_bsn * stride_bsn - ) - # channel-wise - elif per_channel_quant: - b_scale_ptrs = ( - b_scale_ptr + off_experts * stride_bse + offs_bn[None, :] * stride_bsn - ) - b_scale = tl.load(b_scale_ptrs) - # Load per-token scale for activations - a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm - a_scale = tl.load(a_scale_ptrs, mask=token_mask, other=0.0)[:, None] - # tensor-wise - else: - a_scale = tl.load(a_scale_ptr) - b_scale = tl.load(b_scale_ptr + off_experts) - - # ----------------------------------------------------------- - # Iterate to compute a block of the C matrix. - # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block - # of fp32 values for higher accuracy. - # `accumulator` will be converted back to fp16 after the loop. - accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - - for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): - # Load the next block of A and B, generate a mask by checking the - # K dimension. 
- if even_Ks: - a = tl.load( - a_ptrs, - mask=token_mask[:, None], - other=0.0, - ) - b = tl.load(b_ptrs) - else: - a = tl.load( - a_ptrs, - mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K), - other=0.0, - ) - b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) - - # We accumulate along the K dimension. - if use_int8_w8a16: - accumulator = tl.dot(a, b.to(compute_type), acc=accumulator) - elif use_fp8_w8a8 or use_int8_w8a8: - if group_k > 0 and group_n > 0: - k_start = k * BLOCK_SIZE_K - offs_ks = k_start // group_k - a_scale = tl.load( - a_scale_ptrs + offs_ks * stride_ask, mask=token_mask, other=0.0 - ) - b_scale = tl.load(b_scale_ptrs + offs_ks * stride_bsk) - - accumulator += tl.dot(a, b) * a_scale[:, None] * b_scale[None, :] - else: - if use_fp8_w8a8: - accumulator = tl.dot(a, b, acc=accumulator) - else: - accumulator += tl.dot(a, b) - else: - accumulator += tl.dot(a, b) - # Advance the ptrs to the next K block. - a_ptrs += BLOCK_SIZE_K * stride_ak - b_ptrs += BLOCK_SIZE_K * stride_bk - - if use_int8_w8a16: - accumulator *= b_scale - elif use_fp8_w8a8 or use_int8_w8a8: - if group_k == 0 or group_n == 0: - accumulator *= a_scale * b_scale - - if bias_ptr is not None: - accumulator += bias - - if MUL_ROUTED_WEIGHT: - moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0) - accumulator *= moe_weight[:, None] - - accumulator = accumulator.to(compute_type) - # ----------------------------------------------------------- - # Write back the block of the output - offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :] - c_mask = token_mask[:, None] & (offs_cn[None, :] < N) - tl.store(c_ptrs, accumulator, mask=c_mask) - - -def moe_align_block_size( - topk_ids: torch.Tensor, block_size: int, num_experts: int -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """ - Aligns the token distribution across experts to be compatible with block - size for matrix multiplication. - - Parameters: - - topk_ids: A tensor of shape [total_tokens, top_k] representing the - top-k expert indices for each token. - - block_size: The block size used in block matrix multiplication. - - num_experts: The total number of experts. - - Returns: - - sorted_token_ids: A tensor containing the sorted token indices according - to their allocated expert. - - expert_ids: A tensor indicating the assigned expert index for each block. - - num_tokens_post_padded: The total number of tokens after padding, - ensuring divisibility by block_size. - - This function pads the number of tokens that each expert needs to process - so that it is divisible by block_size. - Padding ensures that during block matrix multiplication, the dimensions - align correctly. - - Example: - Given topk_ids = [[2, 3, 4], [1, 2, 4], [1, 3, 4], [1, 2, 3]], - block_size = 4, and num_experts = 4: - - We initially have 12 tokens (after repeating 'top_k' times) and 4 experts, - with each expert needing to process 3 tokens. - - As block_size is 4, we pad 1 token for each expert. - - First, flatten topk_ids to [2, 3, 4, 1, 2, 4, 1, 3, 4, 1, 2, 3]. - - Then append padding tokens [12, 12, 12, 12] for each block. - - After sorting by expert index, we obtain token_ids - [3, 6, 9, 12, 0, 4, 10, 12, 1, 7, 11, 12, 2, 5, 8, 12]. - Tokens 12 are non-existent (padding) and are ignored in - the subsequent matrix multiplication. 
- - The padding ensures that the total number of tokens is now divisible - by block_size for proper block matrix operations. - """ - max_num_tokens_padded = topk_ids.numel() + (num_experts + 1) * (block_size - 1) - sorted_ids = torch.empty( - (max_num_tokens_padded,), dtype=torch.int32, device=topk_ids.device - ) - max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size) - expert_ids = torch.empty( - (max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device - ) - num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device) - - # In EP, expert_ids for filtered experts are -1. We have num_experts + 1 ids in total. - cumsum_buffer = torch.empty( - (num_experts + 2,), dtype=torch.int32, device=topk_ids.device - ) - - # Threshold based on benchmark results - fuse_sorted_ids_padding = sorted_ids.shape[0] <= 4096 - if not fuse_sorted_ids_padding: - sorted_ids.fill_(topk_ids.numel()) - - sgl_moe_align_block_size( - topk_ids, - num_experts + 1, - block_size, - sorted_ids, - expert_ids, - num_tokens_post_pad, - cumsum_buffer, - fuse_sorted_ids_padding, - ) - return sorted_ids, expert_ids, num_tokens_post_pad - - -def invoke_fused_moe_kernel( - A: torch.Tensor, - B: torch.Tensor, - bias: Optional[torch.Tensor], - C: torch.Tensor, - A_scale: Optional[torch.Tensor], - B_scale: Optional[torch.Tensor], - B_zp: Optional[torch.Tensor], - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - sorted_token_ids: torch.Tensor, - expert_ids: torch.Tensor, - num_tokens_post_padded: torch.Tensor, - mul_routed_weight: bool, - top_k: int, - config: Dict[str, Any], - compute_type: tl.dtype, - use_fp8_w8a8: bool, - use_int8_w8a8: bool, - use_int8_w8a16: bool, - use_int4_w4a16: bool, - per_channel_quant: bool, - block_shape: Optional[List[int]] = None, - no_combine: bool = False, -) -> None: - assert topk_weights.stride(1) == 1 - assert sorted_token_ids.stride(0) == 1 - - padded_size = 0 - if use_fp8_w8a8: - assert B_scale is not None - if block_shape is None: - # activation tensor-wise fp8 quantization, dynamic or static - padded_size = padding_size - # activations apply per-token quantization when weights apply per-channel quantization by default - A, A_scale = scaled_fp8_quant( - A, A_scale, use_per_token_if_dynamic=per_channel_quant - ) - else: - # activation block-wise fp8 quantization - assert len(block_shape) == 2 - block_n, block_k = block_shape[0], block_shape[1] - if _is_cuda: - A, A_scale = sglang_per_token_group_quant_fp8(A, block_k) - else: - A, A_scale = per_token_group_quant_fp8(A, block_k) - assert triton.cdiv(A.shape[-1], block_k) == A_scale.shape[-1] - assert triton.cdiv(B.shape[-2], block_n) == B_scale.shape[-2] - assert triton.cdiv(B.shape[-1], block_k) == B_scale.shape[-1] - elif use_int8_w8a8: - assert B_scale is not None - if block_shape is None: - # activation channel-wise int8 quantization - assert ( - per_channel_quant - ), "int8 quantization only supports channel-wise quantization except for block-wise quantization" - A, A_scale = per_token_quant_int8(A) - else: - # activation block-wise int8 quantization - assert len(block_shape) == 2 - block_n, block_k = block_shape[0], block_shape[1] - if _is_cuda: - A, A_scale = sglang_per_token_group_quant_int8(A, block_k) - else: - A, A_scale = per_token_group_quant_int8(A, block_k) - assert triton.cdiv(A.shape[-1], block_k) == A_scale.shape[-1] - assert triton.cdiv(B.shape[-2], block_n) == B_scale.shape[-2] - assert triton.cdiv(B.shape[-1], block_k) == B_scale.shape[-1] - elif use_int8_w8a16 or use_int4_w4a16: - 
assert B_scale is not None - assert block_shape is None or block_shape[0] == 0 - else: - assert A_scale is None - assert B_scale is None - - grid = lambda META: ( - triton.cdiv(sorted_token_ids.shape[0], META["BLOCK_SIZE_M"]) - * triton.cdiv(B.shape[1], META["BLOCK_SIZE_N"]), - ) - - K = B.shape[2] - padded_size - if K % config["BLOCK_SIZE_K"] == 0: - even_Ks = True - else: - even_Ks = False - - if ( - (use_int8_w8a16 or use_int4_w4a16) - and block_shape is not None - and block_shape[1] > 0 - ): - assert B_scale is not None and B_scale.ndim == 3 - assert B_zp is None or B_zp.ndim == 3 - assert bias is None - fused_moe_kernel_gptq_awq[grid]( - A, - B, - C, - B_scale, - B_zp, - topk_weights, - sorted_token_ids, - expert_ids, - num_tokens_post_padded, - B.shape[1], - A.shape[1], - sorted_token_ids.shape[0], - topk_ids.numel(), - A.stride(0), - A.stride(1), - B.stride(0), - B.stride(2), - B.stride(1), - C.stride(1), - C.stride(2), - B_scale.stride(0), - B_scale.stride(2), - B_scale.stride(1), - B_zp.stride(0) if B_zp is not None else 0, - B_zp.stride(2) if B_zp is not None else 0, - B_zp.stride(1) if B_zp is not None else 0, - group_size=block_shape[1], - MUL_ROUTED_WEIGHT=mul_routed_weight, - top_k=top_k, - compute_type=compute_type, - has_zp=B_zp is not None, - use_int4_w4a16=use_int4_w4a16, - use_int8_w8a16=use_int8_w8a16, - even_Ks=even_Ks, - **config, - ) - - else: - - fused_moe_kernel[grid]( - A, - B, - bias, - C, - A_scale, - B_scale, - topk_weights, - sorted_token_ids, - expert_ids, - num_tokens_post_padded, - B.shape[1], - B.shape[2] - padded_size, - sorted_token_ids.shape[0], - topk_ids.numel(), - A.stride(0), - A.stride(1), - B.stride(0), - B.stride(2), - B.stride(1), - bias.stride(0) if bias is not None else 0, - bias.stride(1) if bias is not None else 0, - C.stride(1), - C.stride(2), - A_scale.stride(0) if A_scale is not None and A_scale.ndim == 2 else 0, - A_scale.stride(1) if A_scale is not None and A_scale.ndim == 2 else 0, - B_scale.stride(0) if B_scale is not None and B_scale.ndim >= 2 else 0, - B_scale.stride(2) if B_scale is not None and B_scale.ndim == 3 else 0, - B_scale.stride(1) if B_scale is not None and B_scale.ndim >= 2 else 0, - 0 if block_shape is None else block_shape[0], - 0 if block_shape is None else block_shape[1], - MUL_ROUTED_WEIGHT=mul_routed_weight, - top_k=top_k, - compute_type=compute_type, - use_fp8_w8a8=use_fp8_w8a8, - use_int8_w8a8=use_int8_w8a8, - use_int8_w8a16=use_int8_w8a16, - per_channel_quant=per_channel_quant, - even_Ks=even_Ks, - **config, - ) - - -def get_config_file_name( - E: int, N: int, dtype: Optional[str], block_shape: Optional[int] = None -) -> str: - device_name = get_device_name().replace(" ", "_") - dtype_selector = "" if not dtype else f",dtype={dtype}" - block_shape_selector = ( - "" if not block_shape or not all(block_shape) else f",block_shape={block_shape}" - ) - return f"E={E},N={N},device_name={device_name}{dtype_selector}{block_shape_selector}.json" - - -@functools.lru_cache -def get_moe_configs( - E: int, - N: int, - dtype: Optional[str], - block_n: Optional[int] = 0, - block_k: Optional[int] = 0, -) -> Optional[Dict[int, Any]]: - """ - Return optimized configurations for the fused MoE kernel. - - The return value will be a dictionary that maps an irregular grid of - batch sizes to configurations of the fused_moe kernel. To evaluate the - kernel on a given batch size bs, the closest batch size in the grid should - be picked and the associated configuration chosen to invoke the kernel. 
- """ - # Supported Triton versions, should be sorted from the newest to the oldest - supported_triton_versions = ["3.3.1", "3.2.0", "3.1.0"] - - # First look up if an optimized configuration is available in the configs - # directory - json_file_name = get_config_file_name(E, N, dtype, [block_n, block_k]) - - # We found that using the fused_moe_kernel config from Triton 3.1.0 with Triton 3.2.0 results in negative performance gains, - # so we also include the Triton version as a key for finding the fused_moe_kernel config to achieve the best performance. - triton_version = triton.__version__ - version_dir = f"triton_{triton_version.replace('.', '_')}" - config_file_path = os.path.join( - os.path.dirname(os.path.realpath(__file__)), - "configs", - version_dir, - json_file_name, - ) - if os.path.exists(config_file_path): - with open(config_file_path) as f: - # Please note that although we find the config files, performance might still be suboptimal. - # This is because the tuning environment might differ from your current environment. - # For example, updating the Triton version might cause all old configs to become suboptimal. - # To achieve the best performance, consider re-tuning the Triton fused MOE kernel in your environment. - # For the tuning method, refer to: https://github.com/sgl-project/sglang/tree/main/benchmark/kernels/fused_moe_triton - logger.info(f"Using MoE kernel config from {config_file_path}.") - # If a configuration has been found, return it - return {int(key): val for key, val in json.load(f).items()} - - # Searching for other triton versions that supports the same config - for try_triton_version in supported_triton_versions: - if try_triton_version == triton_version: - continue - try_config_file_path = os.path.join( - os.path.dirname(os.path.realpath(__file__)), - "configs", - f"triton_{try_triton_version.replace('.', '_')}", - json_file_name, - ) - if os.path.exists(try_config_file_path): - with open(try_config_file_path) as f: - logger.warning( - f"Config file not found at {config_file_path}. Fallback to triton version {try_triton_version} and use MoE kernel config from {try_config_file_path}. Performance might be sub-optimal!", - ) - # If a configuration has been found, return it - return {int(key): val for key, val in json.load(f).items()} - - # If no optimized configuration is available, we will use the default - # configuration - logger.warning( - ( - "Using default MoE kernel config. Performance might be sub-optimal! 
" - "Config file not found at %s, you can create them with https://github.com/sgl-project/sglang/tree/main/benchmark/kernels/fused_moe_triton" - ), - config_file_path, - ) - return None - - -def get_default_config( - M: int, - E: int, - N: int, - K: int, - topk: int, - dtype: Optional[str], - is_marlin: bool, - block_shape: Optional[List[int]] = None, -) -> Dict[str, int]: - if dtype == "fp8_w8a8": - if block_shape is None: - config = { - "BLOCK_SIZE_M": 128, - "BLOCK_SIZE_N": 256, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 32, - "num_warps": 8, - "num_stages": 2 if _is_hip else 4, - } - if M <= E: - config = { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, - "num_warps": 4, - "num_stages": 2 if _is_hip else 4, - } - else: - # Block-wise quant: BLOCK_SIZE_K must be divisible by block_shape[1] - config = { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": block_shape[0], - "BLOCK_SIZE_K": block_shape[1], - "GROUP_SIZE_M": 32, - "num_warps": 4, - "num_stages": 2 if _is_hip else 3, - } - else: - config = { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 32, - "GROUP_SIZE_M": 8, - } - # A heuristic: fused marlin works faster with this config for small M - if M <= E or (is_marlin and M <= 32): - config = { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 1, - } - return config - - -def try_get_optimal_moe_config( - w1_shape: Tuple[int, ...], - w2_shape: Tuple[int, ...], - top_k: int, - dtype: Optional[str], - M: int, - is_marlin: bool = False, - block_shape: Optional[List[int]] = None, -): - from sglang.srt.layers.moe.fused_moe_triton import get_config - - override_config = get_config() - if override_config: - config = override_config - else: - # First try to load optimal config from the file - E, _, N = w2_shape - block_n = block_shape[0] if block_shape else 0 - block_k = block_shape[1] if block_shape else 0 - configs = get_moe_configs(E, N, dtype, block_n, block_k) - - if configs: - # If an optimal configuration map has been found, look up the - # optimal config - config = configs[min(configs.keys(), key=lambda x: abs(x - M))] - else: - # Else use the default config - config = get_default_config( - M, E, N, w1_shape[2], top_k, dtype, is_marlin, block_shape - ) - return config - - -def get_config_dtype_str( - dtype: torch.dtype, - use_int8_w8a16: Optional[bool] = False, - use_int4_w4a16: Optional[bool] = False, - use_fp8_w8a8: Optional[bool] = False, - use_int8_w8a8: Optional[bool] = False, -): - if use_fp8_w8a8: - return "fp8_w8a8" - elif use_int8_w8a8: - return "int8_w8a8" - elif use_int4_w4a16: - return "int4_w4a16" - elif use_int8_w8a16: - return "int8_w8a16" - elif dtype == torch.float: - # avoiding cases where kernel fails when float32 MoE - # use fp16/bfloat16 configs - return "float32" - return None - - def inplace_fused_experts( hidden_states: torch.Tensor, w1: torch.Tensor, @@ -1276,92 +319,6 @@ def fused_experts( ) -# _moe_sum_reduce_kernel kernel modified from https://github.com/ModelTC/lightllm/blob/main/lightllm/common/fused_moe/moe_sum_reduce.py -@triton.jit -def _moe_sum_reduce_kernel( - input_ptr, - input_stride_0, - input_stride_1, - input_stride_2, - output_ptr, - output_stride_0, - output_stride_1, - token_num: int, - topk_num: int, - hidden_dim: int, - routed_scaling_factor: tl.constexpr, - BLOCK_M: tl.constexpr, - BLOCK_DIM: tl.constexpr, - NUM_STAGE: tl.constexpr, -): - input_stride_0 = tl.cast(input_stride_0, dtype=tl.int64) - input_stride_1 = tl.cast(input_stride_1, dtype=tl.int64) - 
output_stride_0 = tl.cast(output_stride_0, dtype=tl.int64)
-
-    token_block_id = tl.program_id(0)
-    dim_block_id = tl.program_id(1)
-
-    token_start = token_block_id * BLOCK_M
-    token_end = min((token_block_id + 1) * BLOCK_M, token_num)
-
-    dim_start = dim_block_id * BLOCK_DIM
-    dim_end = min((dim_block_id + 1) * BLOCK_DIM, hidden_dim)
-
-    offs_dim = dim_start + tl.arange(0, BLOCK_DIM)
-
-    for token_index in range(token_start, token_end):
-        accumulator = tl.zeros((BLOCK_DIM,), dtype=tl.float32)
-        input_t_ptr = input_ptr + token_index * input_stride_0 + offs_dim
-        for i in tl.range(0, topk_num, num_stages=NUM_STAGE):
-            tmp = tl.load(
-                input_t_ptr + i * input_stride_1, mask=offs_dim < dim_end, other=0.0
-            )
-            accumulator += tmp
-        accumulator = accumulator * routed_scaling_factor
-        store_t_ptr = output_ptr + token_index * output_stride_0 + offs_dim
-        tl.store(
-            store_t_ptr,
-            accumulator.to(input_ptr.dtype.element_ty),
-            mask=offs_dim < dim_end,
-        )
-
-
-def moe_sum_reduce_triton(
-    input: torch.Tensor, output: torch.Tensor, routed_scaling_factor: float
-):
-    assert input.is_contiguous()
-    assert output.is_contiguous()
-
-    token_num, topk_num, hidden_dim = input.shape
-    assert output.shape[0] == token_num and output.shape[1] == hidden_dim
-
-    BLOCK_M = 1
-    BLOCK_DIM = 2048
-    NUM_STAGE = 1
-    num_warps = 8
-
-    grid = (
-        triton.cdiv(token_num, BLOCK_M),
-        triton.cdiv(hidden_dim, BLOCK_DIM),
-    )
-
-    _moe_sum_reduce_kernel[grid](
-        input,
-        *input.stride(),
-        output,
-        *output.stride(),
-        token_num=token_num,
-        topk_num=topk_num,
-        hidden_dim=hidden_dim,
-        routed_scaling_factor=routed_scaling_factor,
-        BLOCK_M=BLOCK_M,
-        BLOCK_DIM=BLOCK_DIM,
-        NUM_STAGE=NUM_STAGE,
-        num_warps=num_warps,
-    )
-    return
-
-
 @torch.compile
 def moe_sum_reduce_torch_compile(x, out, routed_scaling_factor):
     torch.sum(x, dim=1, out=out)
diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py b/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py
new file mode 100644
index 000000000..51114aade
--- /dev/null
+++ b/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py
@@ -0,0 +1,212 @@
+from __future__ import annotations
+
+import functools
+import json
+import logging
+import os
+from typing import Any, Dict, List, Optional, Tuple
+
+import torch
+import triton
+
+from sglang.srt.utils import get_device_name, is_hip
+
+logger = logging.getLogger(__name__)
+_is_hip = is_hip()
+
+
+def get_config_file_name(
+    E: int, N: int, dtype: Optional[str], block_shape: Optional[List[int]] = None
+) -> str:
+    device_name = get_device_name().replace(" ", "_")
+    dtype_selector = "" if not dtype else f",dtype={dtype}"
+    block_shape_selector = (
+        "" if not block_shape or not all(block_shape) else f",block_shape={block_shape}"
+    )
+    return f"E={E},N={N},device_name={device_name}{dtype_selector}{block_shape_selector}.json"
+
+
+@functools.lru_cache
+def get_moe_configs(
+    E: int,
+    N: int,
+    dtype: Optional[str],
+    block_n: Optional[int] = 0,
+    block_k: Optional[int] = 0,
+) -> Optional[Dict[int, Any]]:
+    """
+    Return optimized configurations for the fused MoE kernel.
+
+    The return value will be a dictionary that maps an irregular grid of
+    batch sizes to configurations of the fused_moe kernel. To evaluate the
+    kernel on a given batch size bs, the closest batch size in the grid should
+    be picked and the associated configuration chosen to invoke the kernel. 
+ """ + # Supported Triton versions, should be sorted from the newest to the oldest + supported_triton_versions = ["3.3.1", "3.2.0", "3.1.0"] + + # First look up if an optimized configuration is available in the configs + # directory + json_file_name = get_config_file_name(E, N, dtype, [block_n, block_k]) + + # We found that using the fused_moe_kernel config from Triton 3.1.0 with Triton 3.2.0 results in negative performance gains, + # so we also include the Triton version as a key for finding the fused_moe_kernel config to achieve the best performance. + triton_version = triton.__version__ + version_dir = f"triton_{triton_version.replace('.', '_')}" + config_file_path = os.path.join( + os.path.dirname(os.path.realpath(__file__)), + "configs", + version_dir, + json_file_name, + ) + if os.path.exists(config_file_path): + with open(config_file_path) as f: + # Please note that although we find the config files, performance might still be suboptimal. + # This is because the tuning environment might differ from your current environment. + # For example, updating the Triton version might cause all old configs to become suboptimal. + # To achieve the best performance, consider re-tuning the Triton fused MOE kernel in your environment. + # For the tuning method, refer to: https://github.com/sgl-project/sglang/tree/main/benchmark/kernels/fused_moe_triton + logger.info(f"Using MoE kernel config from {config_file_path}.") + # If a configuration has been found, return it + return {int(key): val for key, val in json.load(f).items()} + + # Searching for other triton versions that supports the same config + for try_triton_version in supported_triton_versions: + if try_triton_version == triton_version: + continue + try_config_file_path = os.path.join( + os.path.dirname(os.path.realpath(__file__)), + "configs", + f"triton_{try_triton_version.replace('.', '_')}", + json_file_name, + ) + if os.path.exists(try_config_file_path): + with open(try_config_file_path) as f: + logger.warning( + f"Config file not found at {config_file_path}. Fallback to triton version {try_triton_version} and use MoE kernel config from {try_config_file_path}. Performance might be sub-optimal!", + ) + # If a configuration has been found, return it + return {int(key): val for key, val in json.load(f).items()} + + # If no optimized configuration is available, we will use the default + # configuration + logger.warning( + ( + "Using default MoE kernel config. Performance might be sub-optimal! 
" + "Config file not found at %s, you can create them with https://github.com/sgl-project/sglang/tree/main/benchmark/kernels/fused_moe_triton" + ), + config_file_path, + ) + return None + + +def get_default_config( + M: int, + E: int, + N: int, + K: int, + topk: int, + dtype: Optional[str], + is_marlin: bool, + block_shape: Optional[List[int]] = None, +) -> Dict[str, int]: + if dtype == "fp8_w8a8": + if block_shape is None: + config = { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 2 if _is_hip else 4, + } + if M <= E: + config = { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 if _is_hip else 4, + } + else: + # Block-wise quant: BLOCK_SIZE_K must be divisible by block_shape[1] + config = { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": block_shape[0], + "BLOCK_SIZE_K": block_shape[1], + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 if _is_hip else 3, + } + else: + config = { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + } + # A heuristic: fused marlin works faster with this config for small M + if M <= E or (is_marlin and M <= 32): + config = { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + } + return config + + +def try_get_optimal_moe_config( + w1_shape: Tuple[int, ...], + w2_shape: Tuple[int, ...], + top_k: int, + dtype: Optional[str], + M: int, + is_marlin: bool = False, + block_shape: Optional[List[int]] = None, +): + from sglang.srt.layers.moe.fused_moe_triton import get_config + + override_config = get_config() + if override_config: + config = override_config + else: + # First try to load optimal config from the file + E, _, N = w2_shape + block_n = block_shape[0] if block_shape else 0 + block_k = block_shape[1] if block_shape else 0 + configs = get_moe_configs(E, N, dtype, block_n, block_k) + + if configs: + # If an optimal configuration map has been found, look up the + # optimal config + config = configs[min(configs.keys(), key=lambda x: abs(x - M))] + else: + # Else use the default config + config = get_default_config( + M, E, N, w1_shape[2], top_k, dtype, is_marlin, block_shape + ) + return config + + +def get_config_dtype_str( + dtype: torch.dtype, + use_int8_w8a16: Optional[bool] = False, + use_int4_w4a16: Optional[bool] = False, + use_fp8_w8a8: Optional[bool] = False, + use_int8_w8a8: Optional[bool] = False, +): + if use_fp8_w8a8: + return "fp8_w8a8" + elif use_int8_w8a8: + return "int8_w8a8" + elif use_int4_w4a16: + return "int4_w4a16" + elif use_int8_w8a16: + return "int8_w8a16" + elif dtype == torch.float: + # avoiding cases where kernel fails when float32 MoE + # use fp16/bfloat16 configs + return "float32" + return None diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py b/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py new file mode 100644 index 000000000..94f356e28 --- /dev/null +++ b/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py @@ -0,0 +1,796 @@ +from __future__ import annotations + +import os +from typing import Any, Dict, List, Optional + +import torch +import triton +import triton.language as tl + +from sglang.srt.layers.quantization.fp8_kernel import ( + per_token_group_quant_fp8, + scaled_fp8_quant, + sglang_per_token_group_quant_fp8, +) +from sglang.srt.layers.quantization.int8_kernel import ( + per_token_group_quant_int8, + 
per_token_quant_int8, + sglang_per_token_group_quant_int8, +) +from sglang.srt.utils import ( + cpu_has_amx_support, + get_bool_env_var, + is_cpu, + is_cuda, + is_hip, +) + +_is_hip = is_hip() +_is_cuda = is_cuda() +_is_cpu_amx_available = cpu_has_amx_support() +_is_cpu = is_cpu() +_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip + +if _is_cuda: + pass +elif _is_cpu and _is_cpu_amx_available: + pass +elif _is_hip: + pass + +padding_size = 128 if bool(int(os.getenv("SGLANG_MOE_PADDING", "0"))) else 0 + + +@triton.jit +def write_zeros_to_output( + c_ptr, + stride_cm, + stride_cn, + pid_n, + N, + offs_token, + token_mask, + BLOCK_SIZE_M, + BLOCK_SIZE_N, + compute_type, +): + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=compute_type) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :] + c_mask = token_mask[:, None] & (offs_cn[None, :] < N) + tl.store(c_ptrs, accumulator, mask=c_mask) + + +@triton.jit +def fused_moe_kernel_gptq_awq( + # Pointers to matrices + a_ptr, + b_ptr, + c_ptr, + b_scale_ptr, + b_zp_ptr, + topk_weights_ptr, + sorted_token_ids_ptr, + expert_ids_ptr, + num_tokens_post_padded_ptr, + # Matrix dimensions + N: tl.constexpr, + K: tl.constexpr, + EM, + num_valid_tokens, + # The stride variables represent how much to increase the ptr by when + # moving by 1 element in a particular dimension. E.g. `stride_am` is + # how much to increase `a_ptr` by to get the element one row down + # (A has M rows). + stride_am, + stride_ak, + stride_be, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + stride_bse, + stride_bsk, + stride_bsn, + stride_bze, + stride_bzk, + stride_bzn, + group_size: tl.constexpr, + # Meta-parameters + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + MUL_ROUTED_WEIGHT: tl.constexpr, + top_k: tl.constexpr, + compute_type: tl.constexpr, + has_zp: tl.constexpr, + use_int4_w4a16: tl.constexpr, + use_int8_w8a16: tl.constexpr, + even_Ks: tl.constexpr, +): + """ + Implements the fused computation for a Mixture of Experts (MOE) using + token and expert matrices. + Key Parameters: + - A: The input tensor representing tokens with shape (*, K), where '*' can + be any shape representing batches and K is the feature dimension of + each token. + - B: The stacked MOE weight tensor with shape (E, N, K), where E is + the number of experts, K is the input feature dimension, and N is + the output feature dimension. + - C: The output cache tensor with shape (M, topk, N), where M is the + total number of tokens post padding, topk is the number of times + each token is repeated, and N is the output feature dimension. + - sorted_token_ids: A tensor containing the sorted indices of tokens, + repeated topk times and arranged by the expert index they are + assigned to. + - expert_ids: A tensor containing the indices of the expert for each + block. It determines which expert matrix from B should be used for + each block in A. + This kernel performs the multiplication of a token by its corresponding + expert matrix as determined by `expert_ids`. The sorting of + `sorted_token_ids` by expert index and padding ensures divisibility by + BLOCK_SIZE_M, which is necessary to maintain consistency in block matrix + multiplication across different blocks processed by the same expert. + """ + # ----------------------------------------------------------- + # Map program ids `pid` to the block of C it should compute. 
+ # This is done in a grouped ordering to promote L2 data reuse. + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + # ---------------------------------------------------------- + # Create pointers for the first blocks of A and B. + # We will advance this pointer as we move in the K direction + # and accumulate + # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers + # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers + num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr) + if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded: + return + offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64) + offs_token = tl.load(sorted_token_ids_ptr + offs_token_id) + token_mask = offs_token < num_valid_tokens + + off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64) + if off_experts == -1: + # ----------------------------------------------------------- + # Write back zeros to the output when the expert is not + # in the current expert parallel rank. + write_zeros_to_output( + c_ptr, + stride_cm, + stride_cn, + pid_n, + N, + offs_token, + token_mask, + BLOCK_SIZE_M, + BLOCK_SIZE_N, + compute_type, + ) + return + + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + ( + offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak + ) + + if use_int4_w4a16: + b_ptrs = ( + b_ptr + + off_experts * stride_be + + (offs_k[:, None] // 2) * stride_bk + + offs_bn[None, :] * stride_bn + ) + b_shifter = (offs_k[:, None] % 2) * 4 + elif use_int8_w8a16: + b_ptrs = ( + b_ptr + + off_experts * stride_be + + offs_k[:, None] * stride_bk + + offs_bn[None, :] * stride_bn + ) + + if not has_zp and use_int4_w4a16: + b_zp_num = 8 + if not has_zp and use_int8_w8a16: + b_zp_num = 128 + elif has_zp and use_int4_w4a16: + b_zp_shifter = (offs_bn[None, :] % 2) * 4 + + # ----------------------------------------------------------- + # Iterate to compute a block of the C matrix. + # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block + # of fp32 values for higher accuracy. + # `accumulator` will be converted back to fp16 after the loop. + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + # Load the next block of A and B, generate a mask by checking the + # K dimension. 
+ + if not even_Ks: + k_mask = offs_k[:, None] < K - k * BLOCK_SIZE_K + k_other = 0.0 + else: + k_mask = None + k_other = None + + a = tl.load( + a_ptrs, + mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K), + other=0.0, + ) + b = tl.load(b_ptrs) + if use_int4_w4a16: + b = (b >> b_shifter) & 0xF + + b_scale_ptrs = ( + b_scale_ptr + + off_experts * stride_bse + + offs_bn[None, :] * stride_bsn + + ((offs_k[:, None] + BLOCK_SIZE_K * k) // group_size) * stride_bsk + ) + b_scale = tl.load(b_scale_ptrs, mask=k_mask, other=k_other) + b_scale = b_scale.to(tl.float32) + + if has_zp and use_int4_w4a16: + offs_k_true = (offs_k[:, None] + BLOCK_SIZE_K * k) // group_size + b_zp_ptrs = ( + b_zp_ptr + + off_experts * stride_bze + + (offs_bn[None, :] // 2) * stride_bzn + + offs_k_true * stride_bzk + ) + b_zp = tl.load(b_zp_ptrs, mask=k_mask, other=k_other) + b_zp = (b_zp >> b_zp_shifter) & 0xF + b_zp = b_zp.to(tl.float32) + elif has_zp and use_int8_w8a16: + offs_k_true = (offs_k[:, None] + BLOCK_SIZE_K * k) // group_size + b_zp_ptrs = ( + b_zp_ptr + + off_experts * stride_bze + + offs_bn[None, :] * stride_bzn + + offs_k_true * stride_bzk + ) + b_zp = tl.load(b_zp_ptrs, mask=k_mask, other=k_other) + b_zp = b_zp.to(tl.float32) + + # We accumulate along the K dimension. + if has_zp: + b = ((b.to(tl.float32) - b_zp) * b_scale).to(compute_type) + else: + b = ((b.to(tl.float32) - b_zp_num) * b_scale).to(compute_type) + accumulator = tl.dot(a, b, acc=accumulator) + + # Advance the ptrs to the next K block. + a_ptrs += BLOCK_SIZE_K * stride_ak + if use_int4_w4a16: + b_ptrs += (BLOCK_SIZE_K // 2) * stride_bk + else: + b_ptrs += BLOCK_SIZE_K * stride_bk + + if MUL_ROUTED_WEIGHT: + moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0) + accumulator = accumulator * moe_weight[:, None] + + accumulator = accumulator.to(compute_type) + # ----------------------------------------------------------- + # Write back the block of the output + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :] + c_mask = token_mask[:, None] & (offs_cn[None, :] < N) + tl.store(c_ptrs, accumulator, mask=c_mask) + + +@triton.jit +def fused_moe_kernel( + # Pointers to matrices + a_ptr, + b_ptr, + bias_ptr, + c_ptr, + a_scale_ptr, + b_scale_ptr, + topk_weights_ptr, + sorted_token_ids_ptr, + expert_ids_ptr, + num_tokens_post_padded_ptr, + # Matrix dimensions + N, + K, + EM, + num_valid_tokens, + # The stride variables represent how much to increase the ptr by when + # moving by 1 element in a particular dimension. E.g. `stride_am` is + # how much to increase `a_ptr` by to get the element one row down + # (A has M rows). + stride_am, + stride_ak, + stride_be, + stride_bk, + stride_bn, + stride_bias_e, + stride_bias_n, + stride_cm, + stride_cn, + stride_asm, + stride_ask, + stride_bse, + stride_bsk, + stride_bsn, + # Block size for block-wise quantization + group_n: tl.constexpr, + group_k: tl.constexpr, + # Meta-parameters + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + MUL_ROUTED_WEIGHT: tl.constexpr, + top_k: tl.constexpr, + compute_type: tl.constexpr, + use_fp8_w8a8: tl.constexpr, + use_int8_w8a8: tl.constexpr, + use_int8_w8a16: tl.constexpr, + per_channel_quant: tl.constexpr, + even_Ks: tl.constexpr, +): + """ + Implements the fused computation for a Mixture of Experts (MOE) using + token and expert matrices. 
+ + Key Parameters: + - A: The input tensor representing tokens with shape (*, K), where '*' can + be any shape representing batches and K is the feature dimension of + each token. + - B: The stacked MOE weight tensor with shape (E, N, K), where E is + the number of experts, K is the input feature dimension, and N is + the output feature dimension. + - C: The output cache tensor with shape (M, topk, N), where M is the + total number of tokens post padding, topk is the number of times + each token is repeated, and N is the output feature dimension. + - sorted_token_ids: A tensor containing the sorted indices of tokens, + repeated topk times and arranged by the expert index they are + assigned to. + - expert_ids: A tensor containing the indices of the expert for each + block. It determines which expert matrix from B should be used for + each block in A. + + This kernel performs the multiplication of a token by its corresponding + expert matrix as determined by `expert_ids`. The sorting of + `sorted_token_ids` by expert index and padding ensures divisibility by + BLOCK_SIZE_M, which is necessary to maintain consistency in block matrix + multiplication across different blocks processed by the same expert. + """ + # ----------------------------------------------------------- + # Map program ids `pid` to the block of C it should compute. + # This is done in a grouped ordering to promote L2 data reuse. + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + # ---------------------------------------------------------- + # Create pointers for the first blocks of A and B. + # We will advance this pointer as we move in the K direction + # and accumulate + # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers + # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers + num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr) + if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded: + return + offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64) + offs_token = tl.load(sorted_token_ids_ptr + offs_token_id) + offs_token = offs_token.to(tl.int64) + token_mask = offs_token < num_valid_tokens + + off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64) + + if off_experts == -1: + # ----------------------------------------------------------- + # Write back zeros to the output when the expert is not + # in the current expert parallel rank. 
+ write_zeros_to_output( + c_ptr, + stride_cm, + stride_cn, + pid_n, + N, + offs_token, + token_mask, + BLOCK_SIZE_M, + BLOCK_SIZE_N, + compute_type, + ) + return + + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + ( + offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak + ) + + b_ptrs = ( + b_ptr + + off_experts * stride_be + + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + ) + if bias_ptr is not None: + bias = tl.load( + bias_ptr + off_experts * stride_bias_e + offs_bn[None, :] * stride_bias_n + ) + if use_int8_w8a16: + b_scale_ptrs = ( + b_scale_ptr + off_experts * stride_bse + offs_bn[None, :] * stride_bsn + ) + b_scale = tl.load(b_scale_ptrs) + + if use_fp8_w8a8 or use_int8_w8a8: + # block-wise + if group_k > 0 and group_n > 0: + a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm + offs_bsn = offs_bn // group_n + b_scale_ptrs = ( + b_scale_ptr + off_experts * stride_bse + offs_bsn * stride_bsn + ) + # channel-wise + elif per_channel_quant: + b_scale_ptrs = ( + b_scale_ptr + off_experts * stride_bse + offs_bn[None, :] * stride_bsn + ) + b_scale = tl.load(b_scale_ptrs) + # Load per-token scale for activations + a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm + a_scale = tl.load(a_scale_ptrs, mask=token_mask, other=0.0)[:, None] + # tensor-wise + else: + a_scale = tl.load(a_scale_ptr) + b_scale = tl.load(b_scale_ptr + off_experts) + + # ----------------------------------------------------------- + # Iterate to compute a block of the C matrix. + # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block + # of fp32 values for higher accuracy. + # `accumulator` will be converted back to fp16 after the loop. + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + # Load the next block of A and B, generate a mask by checking the + # K dimension. + if even_Ks: + a = tl.load( + a_ptrs, + mask=token_mask[:, None], + other=0.0, + ) + b = tl.load(b_ptrs) + else: + a = tl.load( + a_ptrs, + mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K), + other=0.0, + ) + b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + + # We accumulate along the K dimension. + if use_int8_w8a16: + accumulator = tl.dot(a, b.to(compute_type), acc=accumulator) + elif use_fp8_w8a8 or use_int8_w8a8: + if group_k > 0 and group_n > 0: + k_start = k * BLOCK_SIZE_K + offs_ks = k_start // group_k + a_scale = tl.load( + a_scale_ptrs + offs_ks * stride_ask, mask=token_mask, other=0.0 + ) + b_scale = tl.load(b_scale_ptrs + offs_ks * stride_bsk) + + accumulator += tl.dot(a, b) * a_scale[:, None] * b_scale[None, :] + else: + if use_fp8_w8a8: + accumulator = tl.dot(a, b, acc=accumulator) + else: + accumulator += tl.dot(a, b) + else: + accumulator += tl.dot(a, b) + # Advance the ptrs to the next K block. 
+ a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + + if use_int8_w8a16: + accumulator *= b_scale + elif use_fp8_w8a8 or use_int8_w8a8: + if group_k == 0 or group_n == 0: + accumulator *= a_scale * b_scale + + if bias_ptr is not None: + accumulator += bias + + if MUL_ROUTED_WEIGHT: + moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0) + accumulator *= moe_weight[:, None] + + accumulator = accumulator.to(compute_type) + # ----------------------------------------------------------- + # Write back the block of the output + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :] + c_mask = token_mask[:, None] & (offs_cn[None, :] < N) + tl.store(c_ptrs, accumulator, mask=c_mask) + + +def invoke_fused_moe_kernel( + A: torch.Tensor, + B: torch.Tensor, + bias: Optional[torch.Tensor], + C: torch.Tensor, + A_scale: Optional[torch.Tensor], + B_scale: Optional[torch.Tensor], + B_zp: Optional[torch.Tensor], + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + sorted_token_ids: torch.Tensor, + expert_ids: torch.Tensor, + num_tokens_post_padded: torch.Tensor, + mul_routed_weight: bool, + top_k: int, + config: Dict[str, Any], + compute_type: tl.dtype, + use_fp8_w8a8: bool, + use_int8_w8a8: bool, + use_int8_w8a16: bool, + use_int4_w4a16: bool, + per_channel_quant: bool, + block_shape: Optional[List[int]] = None, + no_combine: bool = False, +) -> None: + assert topk_weights.stride(1) == 1 + assert sorted_token_ids.stride(0) == 1 + + padded_size = 0 + if use_fp8_w8a8: + assert B_scale is not None + if block_shape is None: + # activation tensor-wise fp8 quantization, dynamic or static + padded_size = padding_size + # activations apply per-token quantization when weights apply per-channel quantization by default + A, A_scale = scaled_fp8_quant( + A, A_scale, use_per_token_if_dynamic=per_channel_quant + ) + else: + # activation block-wise fp8 quantization + assert len(block_shape) == 2 + block_n, block_k = block_shape[0], block_shape[1] + if _is_cuda: + A, A_scale = sglang_per_token_group_quant_fp8(A, block_k) + else: + A, A_scale = per_token_group_quant_fp8(A, block_k) + assert triton.cdiv(A.shape[-1], block_k) == A_scale.shape[-1] + assert triton.cdiv(B.shape[-2], block_n) == B_scale.shape[-2] + assert triton.cdiv(B.shape[-1], block_k) == B_scale.shape[-1] + elif use_int8_w8a8: + assert B_scale is not None + if block_shape is None: + # activation channel-wise int8 quantization + assert ( + per_channel_quant + ), "int8 quantization only supports channel-wise quantization except for block-wise quantization" + A, A_scale = per_token_quant_int8(A) + else: + # activation block-wise int8 quantization + assert len(block_shape) == 2 + block_n, block_k = block_shape[0], block_shape[1] + if _is_cuda: + A, A_scale = sglang_per_token_group_quant_int8(A, block_k) + else: + A, A_scale = per_token_group_quant_int8(A, block_k) + assert triton.cdiv(A.shape[-1], block_k) == A_scale.shape[-1] + assert triton.cdiv(B.shape[-2], block_n) == B_scale.shape[-2] + assert triton.cdiv(B.shape[-1], block_k) == B_scale.shape[-1] + elif use_int8_w8a16 or use_int4_w4a16: + assert B_scale is not None + assert block_shape is None or block_shape[0] == 0 + else: + assert A_scale is None + assert B_scale is None + + grid = lambda META: ( + triton.cdiv(sorted_token_ids.shape[0], META["BLOCK_SIZE_M"]) + * triton.cdiv(B.shape[1], META["BLOCK_SIZE_N"]), + ) + + K = B.shape[2] - padded_size + if K % 
config["BLOCK_SIZE_K"] == 0: + even_Ks = True + else: + even_Ks = False + + if ( + (use_int8_w8a16 or use_int4_w4a16) + and block_shape is not None + and block_shape[1] > 0 + ): + assert B_scale is not None and B_scale.ndim == 3 + assert B_zp is None or B_zp.ndim == 3 + assert bias is None + fused_moe_kernel_gptq_awq[grid]( + A, + B, + C, + B_scale, + B_zp, + topk_weights, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + B.shape[1], + A.shape[1], + sorted_token_ids.shape[0], + topk_ids.numel(), + A.stride(0), + A.stride(1), + B.stride(0), + B.stride(2), + B.stride(1), + C.stride(1), + C.stride(2), + B_scale.stride(0), + B_scale.stride(2), + B_scale.stride(1), + B_zp.stride(0) if B_zp is not None else 0, + B_zp.stride(2) if B_zp is not None else 0, + B_zp.stride(1) if B_zp is not None else 0, + group_size=block_shape[1], + MUL_ROUTED_WEIGHT=mul_routed_weight, + top_k=top_k, + compute_type=compute_type, + has_zp=B_zp is not None, + use_int4_w4a16=use_int4_w4a16, + use_int8_w8a16=use_int8_w8a16, + even_Ks=even_Ks, + **config, + ) + + else: + + fused_moe_kernel[grid]( + A, + B, + bias, + C, + A_scale, + B_scale, + topk_weights, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + B.shape[1], + B.shape[2] - padded_size, + sorted_token_ids.shape[0], + topk_ids.numel(), + A.stride(0), + A.stride(1), + B.stride(0), + B.stride(2), + B.stride(1), + bias.stride(0) if bias is not None else 0, + bias.stride(1) if bias is not None else 0, + C.stride(1), + C.stride(2), + A_scale.stride(0) if A_scale is not None and A_scale.ndim == 2 else 0, + A_scale.stride(1) if A_scale is not None and A_scale.ndim == 2 else 0, + B_scale.stride(0) if B_scale is not None and B_scale.ndim >= 2 else 0, + B_scale.stride(2) if B_scale is not None and B_scale.ndim == 3 else 0, + B_scale.stride(1) if B_scale is not None and B_scale.ndim >= 2 else 0, + 0 if block_shape is None else block_shape[0], + 0 if block_shape is None else block_shape[1], + MUL_ROUTED_WEIGHT=mul_routed_weight, + top_k=top_k, + compute_type=compute_type, + use_fp8_w8a8=use_fp8_w8a8, + use_int8_w8a8=use_int8_w8a8, + use_int8_w8a16=use_int8_w8a16, + per_channel_quant=per_channel_quant, + even_Ks=even_Ks, + **config, + ) + + +# _moe_sum_reduce_kernel kernel modified from https://github.com/ModelTC/lightllm/blob/main/lightllm/common/fused_moe/moe_sum_reduce.py +@triton.jit +def _moe_sum_reduce_kernel( + input_ptr, + input_stride_0, + input_stride_1, + input_stride_2, + output_ptr, + output_stride_0, + output_stride_1, + token_num: int, + topk_num: int, + hidden_dim: int, + routed_scaling_factor: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_DIM: tl.constexpr, + NUM_STAGE: tl.constexpr, +): + input_stride_0 = tl.cast(input_stride_0, dtype=tl.int64) + input_stride_1 = tl.cast(input_stride_1, dtype=tl.int64) + output_stride_0 = tl.cast(output_stride_0, dtype=tl.int64) + + token_block_id = tl.program_id(0) + dim_block_id = tl.program_id(1) + + token_start = token_block_id * BLOCK_M + token_end = min((token_block_id + 1) * BLOCK_M, token_num) + + dim_start = dim_block_id * BLOCK_DIM + dim_end = min((dim_block_id + 1) * BLOCK_DIM, hidden_dim) + + offs_dim = dim_start + tl.arange(0, BLOCK_DIM) + + for token_index in range(token_start, token_end): + accumulator = tl.zeros((BLOCK_DIM,), dtype=tl.float32) + input_t_ptr = input_ptr + token_index * input_stride_0 + offs_dim + for i in tl.range(0, topk_num, num_stages=NUM_STAGE): + tmp = tl.load( + input_t_ptr + i * input_stride_1, mask=offs_dim < dim_end, other=0.0 + ) + accumulator += tmp + 
accumulator = accumulator * routed_scaling_factor + store_t_ptr = output_ptr + token_index * output_stride_0 + offs_dim + tl.store( + store_t_ptr, + accumulator.to(input_ptr.dtype.element_ty), + mask=offs_dim < dim_end, + ) + + +def moe_sum_reduce_triton( + input: torch.Tensor, output: torch.Tensor, routed_scaling_factor: float +): + assert input.is_contiguous() + assert output.is_contiguous() + + token_num, topk_num, hidden_dim = input.shape + assert output.shape[0] == token_num and output.shape[1] == hidden_dim + + BLOCK_M = 1 + BLOCK_DIM = 2048 + NUM_STAGE = 1 + num_warps = 8 + + grid = ( + triton.cdiv(token_num, BLOCK_M), + triton.cdiv(hidden_dim, BLOCK_DIM), + ) + + _moe_sum_reduce_kernel[grid]( + input, + *input.stride(), + output, + *output.stride(), + token_num=token_num, + topk_num=topk_num, + hidden_dim=hidden_dim, + routed_scaling_factor=routed_scaling_factor, + BLOCK_M=BLOCK_M, + BLOCK_DIM=BLOCK_DIM, + NUM_STAGE=NUM_STAGE, + num_warps=num_warps, + ) + return diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/moe_align_block_size.py b/python/sglang/srt/layers/moe/fused_moe_triton/moe_align_block_size.py new file mode 100644 index 000000000..64d0126d6 --- /dev/null +++ b/python/sglang/srt/layers/moe/fused_moe_triton/moe_align_block_size.py @@ -0,0 +1,87 @@ +from __future__ import annotations + +from typing import Tuple + +import torch +import triton + +from sglang.srt.utils import is_cuda, is_hip + +_is_cuda = is_cuda() +_is_hip = is_hip() + +if _is_cuda or _is_hip: + from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size + + +def moe_align_block_size( + topk_ids: torch.Tensor, block_size: int, num_experts: int +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Aligns the token distribution across experts to be compatible with block + size for matrix multiplication. + + Parameters: + - topk_ids: A tensor of shape [total_tokens, top_k] representing the + top-k expert indices for each token. + - block_size: The block size used in block matrix multiplication. + - num_experts: The total number of experts. + + Returns: + - sorted_token_ids: A tensor containing the sorted token indices according + to their allocated expert. + - expert_ids: A tensor indicating the assigned expert index for each block. + - num_tokens_post_padded: The total number of tokens after padding, + ensuring divisibility by block_size. + + This function pads the number of tokens that each expert needs to process + so that it is divisible by block_size. + Padding ensures that during block matrix multiplication, the dimensions + align correctly. + + Example: + Given topk_ids = [[2, 3, 4], [1, 2, 4], [1, 3, 4], [1, 2, 3]], + block_size = 4, and num_experts = 4: + - We initially have 12 tokens (after repeating 'top_k' times) and 4 experts, + with each expert needing to process 3 tokens. + - As block_size is 4, we pad 1 token for each expert. + - First, flatten topk_ids to [2, 3, 4, 1, 2, 4, 1, 3, 4, 1, 2, 3]. + - Then append padding tokens [12, 12, 12, 12] for each block. + - After sorting by expert index, we obtain token_ids + [3, 6, 9, 12, 0, 4, 10, 12, 1, 7, 11, 12, 2, 5, 8, 12]. + Tokens 12 are non-existent (padding) and are ignored in + the subsequent matrix multiplication. + - The padding ensures that the total number of tokens is now divisible + by block_size for proper block matrix operations. 
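+        - For this example, the four resulting blocks map to expert_ids
+          [1, 2, 3, 4], and num_tokens_post_padded is 16.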
+    """
+    max_num_tokens_padded = topk_ids.numel() + (num_experts + 1) * (block_size - 1)
+    sorted_ids = torch.empty(
+        (max_num_tokens_padded,), dtype=torch.int32, device=topk_ids.device
+    )
+    max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size)
+    expert_ids = torch.empty(
+        (max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
+    )
+    num_tokens_post_pad = torch.empty((1,), dtype=torch.int32, device=topk_ids.device)
+
+    # In EP, expert_ids for experts filtered out of the current rank are -1,
+    # so there are num_experts + 1 distinct ids in total.
+    cumsum_buffer = torch.empty(
+        (num_experts + 2,), dtype=torch.int32, device=topk_ids.device
+    )
+
+    # For small buffers, let the kernel write the padding sentinel itself;
+    # otherwise pre-fill it on the host (threshold based on benchmark results).
+    fuse_sorted_ids_padding = sorted_ids.shape[0] <= 4096
+    if not fuse_sorted_ids_padding:
+        sorted_ids.fill_(topk_ids.numel())
+
+    sgl_moe_align_block_size(
+        topk_ids,
+        num_experts + 1,
+        block_size,
+        sorted_ids,
+        expert_ids,
+        num_tokens_post_pad,
+        cumsum_buffer,
+        fuse_sorted_ids_padding,
+    )
+    return sorted_ids, expert_ids, num_tokens_post_pad
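For reference, a minimal sketch of calling the relocated helper — assuming a CUDA or HIP build where `sgl_kernel` is importable; the shapes, routing, and variable names below are illustrative and not part of the diff:

import torch
import triton

from sglang.srt.layers.moe.fused_moe_triton.moe_align_block_size import (
    moe_align_block_size,
)

num_experts, block_size = 8, 64
# Four tokens, each routed to its top-2 experts.
topk_ids = torch.randint(0, num_experts, (4, 2), dtype=torch.int32, device="cuda")

sorted_ids, expert_ids, num_tokens_post_pad = moe_align_block_size(
    topk_ids, block_size, num_experts
)

# One slot per (token, expert) pair, plus up to block_size - 1 padding slots
# for each expert id (including the extra -1 id used under expert parallelism).
assert sorted_ids.numel() == topk_ids.numel() + (num_experts + 1) * (block_size - 1)
assert expert_ids.numel() == triton.cdiv(sorted_ids.numel(), block_size)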