# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Fused MoE Triton kernels."""
import functools
import json
import os
# torch.compile needs typing.List. It will fail torch.library.infer_schema
# otherwise
from typing import List  # noqa: UP035
from typing import Any, Callable, Optional, Union

import torch
import torch.nn.functional as F

import vllm.envs as envs
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm import _custom_ops as ops
from vllm.logger import init_logger
# yapf: disable
from vllm.model_executor.layers.fused_moe.config import (
    FUSED_MOE_UNQUANTIZED_CONFIG, FusedMoEQuantConfig, _get_config_dtype_str)
from vllm.model_executor.layers.fused_moe.cutlass_moe import (
    _valid_cutlass_block_scaled_grouped_gemm,
    run_cutlass_block_scaled_fused_experts)
# yapf: enable
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
    _valid_deep_gemm, deep_gemm_moe_fp8)
from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
    moe_align_block_size)
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
    MoEPrepareAndFinalizeNoEP)
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
    TopKWeightAndReduceNoOP)
from vllm.model_executor.layers.fused_moe.utils import (
    _resize_cache, activation_without_mul, moe_kernel_quantize_input)
from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
    dequant_mxfp4)
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
from vllm.utils import direct_register_custom_op, is_torch_equal_or_newer
from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used

from .rocm_aiter_fused_moe import is_rocm_aiter_moe_enabled

logger = init_logger(__name__)


@triton.jit
def write_zeros_to_output(c_ptr, stride_cm, stride_cn, pid_n, N, offs_token,
                          token_mask, BLOCK_SIZE_M, BLOCK_SIZE_N,
                          compute_type):
    """Store a zero tile into C for the rows in `offs_token`.

    Writes a (BLOCK_SIZE_M, BLOCK_SIZE_N) tile of zeros (in `compute_type`)
    to the `pid_n`-th BLOCK_SIZE_N-wide column block of C. Rows whose
    `token_mask` is False and columns >= N are masked out of the store.
    """
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=compute_type)
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[
        None, :]
    c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
    tl.store(c_ptrs, accumulator, mask=c_mask)


@triton.jit
def fused_moe_kernel_gptq_awq(
        # Pointers to matrices
        a_ptr,
        b_ptr,
        c_ptr,
        b_scale_ptr,
        b_zp_ptr,
        topk_weights_ptr,
        sorted_token_ids_ptr,
        expert_ids_ptr,
        num_tokens_post_padded_ptr,
        # Matrix dimensions
        N: tl.constexpr,
        K: tl.constexpr,
        EM,
        num_valid_tokens,
        # The stride variables represent how much to increase the ptr by when
        # moving by 1 element in a particular dimension. E.g. `stride_am` is
        # how much to increase `a_ptr` by to get the element one row down
        # (A has M rows).
        stride_am,
        stride_ak,
        stride_be,
        stride_bk,
        stride_bn,
        stride_cm,
        stride_cn,
        stride_bse,
        stride_bsk,
        stride_bsn,
        stride_bze,
        stride_bzk,
        stride_bzn,
        block_k_diviable: tl.constexpr,
        group_size: tl.constexpr,
        # Meta-parameters
        BLOCK_SIZE_M: tl.constexpr,
        BLOCK_SIZE_N: tl.constexpr,
        BLOCK_SIZE_K: tl.constexpr,
        GROUP_SIZE_M: tl.constexpr,
        MUL_ROUTED_WEIGHT: tl.constexpr,
        top_k: tl.constexpr,
        compute_type: tl.constexpr,
        has_zp: tl.constexpr,
        use_int4_w4a16: tl.constexpr,
        use_int8_w8a16: tl.constexpr):
    """
    Implements the fused computation for a Mixture of Experts (MOE) using
    token and expert matrices.

    Key Parameters:
    - A: The input tensor representing tokens with shape (*, K), where '*' can
        be any shape representing batches and K is the feature dimension of
        each token.
    - B: The stacked MOE weight tensor with shape (E, N, K), where E is
        the number of experts, K is the input feature dimension, and N is
        the output feature dimension.
    - C: The output cache tensor with shape (M, topk, N), where M is the
        total number of tokens post padding, topk is the number of times
        each token is repeated, and N is the output feature dimension.
    - sorted_token_ids: A tensor containing the sorted indices of tokens,
        repeated topk times and arranged by the expert index they are
        assigned to.
    - expert_ids: A tensor containing the indices of the expert for each
        block. It determines which expert matrix from B should be used for
        each block in A.
    This kernel performs the multiplication of a token by its corresponding
    expert matrix as determined by `expert_ids`. The sorting of
    `sorted_token_ids` by expert index and padding ensures divisibility by
    BLOCK_SIZE_M, which is necessary to maintain consistency in block matrix
    multiplication across different blocks processed by the same expert.

    Weights are int4 (two values packed per element along K) or int8, with
    per-group scales (`group_size` along K) and optional zero points.
    """
    # -----------------------------------------------------------
    # Map program ids `pid` to the block of C it should compute.
    # This is done in a grouped ordering to promote L2 data reuse.
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    # ----------------------------------------------------------
    # Create pointers for the first blocks of A and B.
    # We will advance this pointer as we move in the K direction
    # and accumulate
    # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers
    # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers
    num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr)
    if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
        return
    offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(
        tl.int64)
    offs_token = tl.load(sorted_token_ids_ptr + offs_token_id)
    # Padding entries in sorted_token_ids are >= num_valid_tokens.
    token_mask = offs_token < num_valid_tokens

    off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64)
    if off_experts == -1:
        # -----------------------------------------------------------
        # Write back zeros to the output when the expert is not
        # in the current expert parallel rank.
        write_zeros_to_output(c_ptr, stride_cm, stride_cn, pid_n, N,
                              offs_token, token_mask, BLOCK_SIZE_M,
                              BLOCK_SIZE_N, compute_type)
        return

    offs_bn = (pid_n * BLOCK_SIZE_N +
               tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    # Each row of A is shared by its top_k copies in sorted_token_ids.
    a_ptrs = a_ptr + (offs_token[:, None] // top_k * stride_am +
                      offs_k[None, :] * stride_ak)

    if use_int4_w4a16:
        # Two int4 values per element along K: element k // 2, nibble k % 2.
        b_ptrs = b_ptr + off_experts * stride_be + \
            (offs_k[:, None] // 2) * stride_bk + offs_bn[None, :] * \
                stride_bn
        b_shifter = (offs_k[:, None] % 2) * 4
    elif use_int8_w8a16:
        b_ptrs = b_ptr + off_experts * stride_be + \
            offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn

    # Without a zero-point tensor a fixed zero point is subtracted
    # (8 for int4, 128 for int8).
    if not has_zp and use_int4_w4a16:
        b_zp_num = 8
    if not has_zp and use_int8_w8a16:
        b_zp_num = 128
    elif has_zp and use_int4_w4a16:
        b_zp_shifter = (offs_bn[None, :] % 2) * 4

    # -----------------------------------------------------------
    # Iterate to compute a block of the C matrix.
    # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
    # of fp32 values for higher accuracy.
    # `accumulator` will be converted back to fp16 after the loop.
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        # Load the next block of A and B, generate a mask by checking the
        # K dimension.

        if not block_k_diviable:
            k_mask = offs_k[:, None] < K - k * BLOCK_SIZE_K
            k_other = 0.0
        else:
            k_mask = None
            k_other = None

        a = tl.load(
            a_ptrs,
            mask=token_mask[:, None] &
            (offs_k[None, :] < K - k * BLOCK_SIZE_K),
            other=0.0)
        b = tl.load(b_ptrs)
        if use_int4_w4a16:
            b = (b >> b_shifter) & 0xF

        b_scale_ptrs = b_scale_ptr + off_experts * stride_bse + \
            offs_bn[None, :] * stride_bsn + \
            ((offs_k[:, None] + BLOCK_SIZE_K * k) // group_size) * \
                stride_bsk
        b_scale = tl.load(b_scale_ptrs, mask=k_mask, other=k_other)
        b_scale = b_scale.to(tl.float32)

        if has_zp and use_int4_w4a16:
            offs_k_true = (offs_k[:, None] + BLOCK_SIZE_K * k) // group_size
            b_zp_ptrs = b_zp_ptr + off_experts * stride_bze + \
                (offs_bn[None, :] // 2) * stride_bzn + \
                offs_k_true * stride_bzk
            b_zp = tl.load(b_zp_ptrs, mask=k_mask, other=k_other)
            b_zp = ((b_zp >> b_zp_shifter) & 0xF)
            b_zp = b_zp.to(tl.float32)
        elif has_zp and use_int8_w8a16:
            offs_k_true = (offs_k[:, None] + BLOCK_SIZE_K * k) // group_size
            b_zp_ptrs = b_zp_ptr + off_experts * stride_bze + \
                offs_bn[None, :] * stride_bzn + \
                offs_k_true * stride_bzk
            b_zp = tl.load(b_zp_ptrs, mask=k_mask, other=k_other)
            b_zp = b_zp.to(tl.float32)

        # We accumulate along the K dimension.
        if has_zp:
            b = ((b.to(tl.float32) - b_zp) * b_scale).to(compute_type)
        else:
            b = ((b.to(tl.float32) - b_zp_num) * b_scale).to(compute_type)
        accumulator = tl.dot(a, b, acc=accumulator)

        # Advance the ptrs to the next K block.
        a_ptrs += BLOCK_SIZE_K * stride_ak
        if use_int4_w4a16:
            b_ptrs += (BLOCK_SIZE_K // 2) * stride_bk
        else:
            b_ptrs += BLOCK_SIZE_K * stride_bk

    if MUL_ROUTED_WEIGHT:
        moe_weight = tl.load(topk_weights_ptr + offs_token,
                             mask=token_mask,
                             other=0)
        accumulator = accumulator * moe_weight[:, None]

    accumulator = accumulator.to(compute_type)
    # -----------------------------------------------------------
    # Write back the block of the output
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[
        None, :]
    c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
    tl.store(c_ptrs, accumulator, mask=c_mask)


@triton.jit
def fused_moe_kernel(
    # Pointers to matrices
    a_ptr,
    b_ptr,
    c_ptr,
    b_bias_ptr,
    a_scale_ptr,
    b_scale_ptr,
    topk_weights_ptr,
    sorted_token_ids_ptr,
    expert_ids_ptr,
    num_tokens_post_padded_ptr,
    # Matrix dimensions
    N,
    K,
    EM,
    num_valid_tokens,
    # The stride variables represent how much to increase the ptr by when
    # moving by 1 element in a particular dimension. E.g. `stride_am` is
    # how much to increase `a_ptr` by to get the element one row down
    # (A has M rows).
    stride_am,
    stride_ak,
    stride_be,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    stride_asm,
    stride_ask,
    stride_bse,
    stride_bsk,
    stride_bsn,
    stride_bbe,  # bias expert stride
    stride_bbn,  # bias N stride
    # Block size for block-wise quantization
    group_n: tl.constexpr,
    group_k: tl.constexpr,
    # Meta-parameters
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
    MUL_ROUTED_WEIGHT: tl.constexpr,
    top_k: tl.constexpr,
    compute_type: tl.constexpr,
    use_fp8_w8a8: tl.constexpr,
    use_int8_w8a8: tl.constexpr,
    use_int8_w8a16: tl.constexpr,
    per_channel_quant: tl.constexpr,
    HAS_BIAS: tl.constexpr,
):
    """
    Implements the fused computation for a Mixture of Experts (MOE) using
    token and expert matrices.

    Key Parameters:
    - A: The input tensor representing tokens with shape (*, K), where '*' can
        be any shape representing batches and K is the feature dimension of
        each token.
    - B: The stacked MOE weight tensor with shape (E, N, K), where E is
        the number of experts, K is the input feature dimension, and N is
        the output feature dimension.
    - C: The output cache tensor with shape (M, topk, N), where M is the
        total number of tokens post padding, topk is the number of times
        each token is repeated, and N is the output feature dimension.
    - sorted_token_ids: A tensor containing the sorted indices of tokens,
        repeated topk times and arranged by the expert index they are
        assigned to.
    - expert_ids: A tensor containing the indices of the expert for each
        block. It determines which expert matrix from B should be used for
        each block in A.
    This kernel performs the multiplication of a token by its corresponding
    expert matrix as determined by `expert_ids`. The sorting of
    `sorted_token_ids` by expert index and padding ensures divisibility by
    BLOCK_SIZE_M, which is necessary to maintain consistency in block matrix
    multiplication across different blocks processed by the same expert.

    Epilogue order (all in fp32): dequantize -> add bias -> multiply by the
    routed weight -> cast to `compute_type`. The bias is added after
    dequantization so it is never scaled by the quantization scales.
    """
    # -----------------------------------------------------------
    # Map program ids `pid` to the block of C it should compute.
    # This is done in a grouped ordering to promote L2 data reuse.
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    # ----------------------------------------------------------
    # Create pointers for the first blocks of A and B.
    # We will advance this pointer as we move in the K direction
    # and accumulate
    # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers
    # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers
    num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr)
    if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
        return
    offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(
        tl.int64)
    offs_token = tl.load(sorted_token_ids_ptr + offs_token_id)
    # Padding entries in sorted_token_ids are >= num_valid_tokens.
    token_mask = offs_token < num_valid_tokens

    off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64)
    if off_experts == -1:
        # -----------------------------------------------------------
        # Write back zeros to the output when the expert is not
        # in the current expert parallel rank.
        write_zeros_to_output(c_ptr, stride_cm, stride_cn, pid_n, N,
                              offs_token, token_mask, BLOCK_SIZE_M,
                              BLOCK_SIZE_N, compute_type)
        return

    offs_bn = (pid_n * BLOCK_SIZE_N +
               tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    # Each row of A is shared by its top_k copies in sorted_token_ids.
    a_ptrs = a_ptr + (offs_token[:, None] // top_k * stride_am +
                      offs_k[None, :] * stride_ak)

    b_ptrs = b_ptr + off_experts * stride_be + (offs_k[:, None] * stride_bk +
                                                offs_bn[None, :] * stride_bn)
    if use_int8_w8a16:
        b_scale_ptrs = b_scale_ptr + off_experts * stride_bse + offs_bn[
            None, :] * stride_bsn
        b_scale = tl.load(b_scale_ptrs)

    if use_fp8_w8a8 or use_int8_w8a8:
        # block-wise
        if group_k > 0 and group_n > 0:
            a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm
            offs_bsn = offs_bn // group_n
            b_scale_ptrs = (b_scale_ptr + off_experts * stride_bse +
                            offs_bsn * stride_bsn)
        # channel-wise
        elif per_channel_quant:
            b_scale_ptrs = b_scale_ptr + off_experts * stride_bse + offs_bn[
                None, :] * stride_bsn
            b_scale = tl.load(b_scale_ptrs)
            # Load per-token scale for activations
            a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm
            a_scale = tl.load(a_scale_ptrs, mask=token_mask, other=0.0)[:,
                                                                         None]
        # tensor-wise
        else:
            a_scale = tl.load(a_scale_ptr)
            b_scale = tl.load(b_scale_ptr + off_experts)
    if HAS_BIAS:
        # bias shape: [num_experts, N]
        bias_ptrs = b_bias_ptr + off_experts * stride_bbe + offs_bn * stride_bbn
        bias = tl.load(bias_ptrs, mask=(offs_bn < N), other=0.0)
    # -----------------------------------------------------------
    # Iterate to compute a block of the C matrix.
    # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
    # of fp32 values for higher accuracy.
    # `accumulator` will be converted back to fp16 after the loop.
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        # Load the next block of A and B, generate a mask by checking the
        # K dimension.
        a = tl.load(a_ptrs,
                    mask=token_mask[:, None] &
                    (offs_k[None, :] < K - k * BLOCK_SIZE_K),
                    other=0.0)
        b = tl.load(b_ptrs,
                    mask=offs_k[:, None] < K - k * BLOCK_SIZE_K,
                    other=0.0)
        # We accumulate along the K dimension.
        if use_int8_w8a16:
            accumulator = tl.dot(a, b.to(compute_type), acc=accumulator)
        elif use_fp8_w8a8 or use_int8_w8a8:
            if group_k > 0 and group_n > 0:
                # Block-wise scales vary along K, so they are applied to
                # each partial product inside the loop.
                k_start = k * BLOCK_SIZE_K
                offs_ks = k_start // group_k
                a_scale = tl.load(a_scale_ptrs + offs_ks * stride_ask,
                                  mask=token_mask,
                                  other=0.0)
                b_scale = tl.load(b_scale_ptrs + offs_ks * stride_bsk)

                accumulator += tl.dot(a, b) * a_scale[:,
                                                      None] * b_scale[None, :]
            else:
                if use_fp8_w8a8:
                    # acc used to enable fp8_fast_accum
                    accumulator = tl.dot(a, b, acc=accumulator)
                else:
                    accumulator += tl.dot(a, b)
        else:
            accumulator += tl.dot(a, b)
        # Advance the ptrs to the next K block.
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk

    # -----------------------------------------------------------
    # Epilogue, done on the fp32 accumulator; cast once at the end.
    #
    # 1. Dequantize. Block-wise scales (group_k > 0 and group_n > 0) were
    #    already applied inside the K loop; the int8_w8a16 and the
    #    per-tensor / per-channel fp8/int8 paths are scaled here.
    if use_int8_w8a16:
        accumulator = accumulator * b_scale
    elif (use_fp8_w8a8 or use_int8_w8a8) and not (group_k > 0
                                                 and group_n > 0):
        accumulator = accumulator * a_scale * b_scale

    # 2. Add the bias after dequantization: it is loaded unscaled from
    #    `b_bias_ptr` and must not be multiplied by the quant scales.
    if HAS_BIAS:
        accumulator = accumulator + bias[None, :]

    # 3. Apply the router weight of each (token, expert) row.
    if MUL_ROUTED_WEIGHT:
        moe_weight = tl.load(topk_weights_ptr + offs_token,
                             mask=token_mask,
                             other=0)
        accumulator = accumulator * moe_weight[:, None]

    accumulator = accumulator.to(compute_type)
    # -----------------------------------------------------------
    # Write back the block of the output
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[
        None, :]
    c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
    tl.store(c_ptrs, accumulator, mask=c_mask)


def invoke_fused_moe_kernel(A: torch.Tensor,
                            B: torch.Tensor,
                            C: torch.Tensor,
                            A_scale: Optional[torch.Tensor],
                            B_scale: Optional[torch.Tensor],
                            B_zp: Optional[torch.Tensor],
                            topk_weights: Optional[torch.Tensor],
                            sorted_token_ids: torch.Tensor,
                            expert_ids: torch.Tensor,
                            num_tokens_post_padded: torch.Tensor,
                            mul_routed_weight: bool,
                            top_k: int,
                            config: dict[str, Any],
                            compute_type: tl.dtype,
                            use_fp8_w8a8: bool,
                            use_int8_w8a8: bool,
                            use_int8_w8a16: bool,
                            use_int4_w4a16: bool,
                            per_channel_quant: bool,
                            block_shape: Optional[list[int]] = None,
                            B_bias: Optional[torch.Tensor] = None) -> None:
    """Launch one fused MoE GEMM and write the result into ``C``.

    Dispatch:
    - int8_w8a16 / int4_w4a16 with a group size (``block_shape[1] > 0``):
      ``ops.moe_wna16_gemm`` when ``should_moe_wna16_use_cuda`` says so,
      otherwise the ``fused_moe_kernel_gptq_awq`` Triton kernel.
    - everything else: the ``fused_moe_kernel`` Triton kernel.

    Args:
        A: Activations; ``A.size(0)`` is the token count M and
            ``A.stride(0)/(1)`` are passed as the row/column strides.
        B: Stacked expert weights; ``B.size(0)`` is the number of experts
            and ``B.size(1)`` the output dimension N.
        C: Output buffer; ``C.stride(1)`` / ``C.stride(2)`` are used as the
            row / column strides of the flattened (token, top_k) rows.
        A_scale, B_scale, B_zp: Quantization scales / zero points (optional,
            depending on the quantization mode).
        topk_weights: Router weights; required if ``mul_routed_weight``.
        sorted_token_ids, expert_ids, num_tokens_post_padded: Expert-aligned
            token ordering, presumably from ``moe_align_block_size``.
        mul_routed_weight: Multiply each output row by its router weight.
        top_k: Number of experts per token.
        config: Triton launch config (BLOCK_SIZE_* etc.); copied, not
            mutated.
        compute_type: Triton dtype of the output tile.
        use_*, per_channel_quant, block_shape: Quantization mode selection.
        B_bias: Optional per-expert bias. NOTE(review): it is only
            forwarded to ``fused_moe_kernel``; the weight-only (wna16)
            paths below do not use it.
    """
    assert topk_weights is not None or not mul_routed_weight
    assert topk_weights is None or topk_weights.stride(1) == 1
    assert sorted_token_ids.stride(0) == 1

    if use_fp8_w8a8 or use_int8_w8a8:
        assert B_scale is not None
        assert (block_shape is None
                or triton.cdiv(B.size(-2), block_shape[0]) == B_scale.size(-2))
        assert (block_shape is None
                or triton.cdiv(B.size(-1), block_shape[1]) == B_scale.size(-1))

    elif use_int8_w8a16 or use_int4_w4a16:
        assert B_scale is not None
        assert block_shape is None or block_shape[0] == 0
    else:
        assert A_scale is None
        assert B_scale is None

    M = A.size(0)
    num_tokens = M * top_k
    num_experts = B.size(0)

    EM = sorted_token_ids.size(0)
    if M < config["BLOCK_SIZE_M"]:
        # optimize for small batch_size.
        # We assume that top_ids of each token is unique,
        # so num_valid_experts <= batch_size <= BLOCK_SIZE_M,
        # and we can skip some invalid blocks.
        EM = min(sorted_token_ids.size(0),
                 M * top_k * config['BLOCK_SIZE_M'])
    grid = lambda META: (triton.cdiv(EM, META['BLOCK_SIZE_M']) * triton.cdiv(
        B.size(1), META['BLOCK_SIZE_N']), )
    HAS_BIAS = B_bias is not None
    if (use_int8_w8a16 or use_int4_w4a16) and \
            block_shape is not None and block_shape[1] > 0:
        assert B_scale is not None and B_scale.ndim == 3
        assert B_zp is None or B_zp.ndim == 3

        use_moe_wna16_cuda = should_moe_wna16_use_cuda(
            num_valid_tokens=num_tokens,
            group_size=block_shape[1],
            num_experts=num_experts,
            bit=4 if use_int4_w4a16 else 8)
        config = config.copy()
        # The expert count is B.size(0); B.size(1) is the output dim N.
        config.update(
            get_moe_wna16_block_config(config=config,
                                       use_moe_wna16_cuda=use_moe_wna16_cuda,
                                       num_valid_tokens=num_tokens,
                                       size_k=A.size(1),
                                       size_n=B.size(1),
                                       num_experts=num_experts,
                                       group_size=block_shape[1],
                                       real_top_k=top_k,
                                       block_size_m=config["BLOCK_SIZE_M"]))

        if use_moe_wna16_cuda:
            bit = 4 if use_int4_w4a16 else 8
            ops.moe_wna16_gemm(A, C, B, B_scale, B_zp,
                               topk_weights if mul_routed_weight else None,
                               sorted_token_ids, expert_ids,
                               num_tokens_post_padded, top_k,
                               config["BLOCK_SIZE_M"], config["BLOCK_SIZE_N"],
                               config["BLOCK_SIZE_K"], bit)
            return

        fused_moe_kernel_gptq_awq[grid](
            A,
            B,
            C,
            B_scale,
            B_zp,
            topk_weights,
            sorted_token_ids,
            expert_ids,
            num_tokens_post_padded,
            B.size(1),
            A.size(1),
            EM,
            num_tokens,
            A.stride(0),
            A.stride(1),
            B.stride(0),
            B.stride(2),
            B.stride(1),
            C.stride(1),
            C.stride(2),
            B_scale.stride(0),
            B_scale.stride(2),
            B_scale.stride(1),
            B_zp.stride(0) if B_zp is not None else 0,
            B_zp.stride(2) if B_zp is not None else 0,
            B_zp.stride(1) if B_zp is not None else 0,
            block_k_diviable=A.size(1) % config["BLOCK_SIZE_K"] == 0,
            group_size=block_shape[1],
            MUL_ROUTED_WEIGHT=mul_routed_weight,
            top_k=top_k,
            compute_type=compute_type,
            has_zp=B_zp is not None,
            use_int4_w4a16=use_int4_w4a16,
            use_int8_w8a16=use_int8_w8a16,
            **config,
        )
    else:
        config = config.copy()
        BLOCK_SIZE_K = config.pop("BLOCK_SIZE_K")
        if block_shape is not None:
            # A K tile must not span more than one quantization block.
            BLOCK_SIZE_K = min(BLOCK_SIZE_K, min(block_shape[0],
                                                 block_shape[1]))
        fused_moe_kernel[grid](
            A,
            B,
            C,
            B_bias,
            A_scale,
            B_scale,
            topk_weights,
            sorted_token_ids,
            expert_ids,
            num_tokens_post_padded,
            B.size(1),
            B.size(2),
            EM,
            num_tokens,
            A.stride(0),
            A.stride(1),
            B.stride(0),
            B.stride(2),
            B.stride(1),
            C.stride(1),
            C.stride(2),
            A_scale.stride(0)
            if A_scale is not None and A_scale.ndim == 2 else 0,
            A_scale.stride(1)
            if A_scale is not None and A_scale.ndim == 2 else 0,
            B_scale.stride(0)
            if B_scale is not None and B_scale.ndim >= 2 else 0,
            B_scale.stride(2)
            if B_scale is not None and B_scale.ndim == 3 else 0,
            B_scale.stride(1)
            if B_scale is not None and B_scale.ndim >= 2 else 0,
            B_bias.stride(0) if B_bias is not None else 0,
            B_bias.stride(1) if B_bias is not None else 0,
            0 if block_shape is None else block_shape[0],
            0 if block_shape is None else block_shape[1],
            MUL_ROUTED_WEIGHT=mul_routed_weight,
            top_k=top_k,
            compute_type=compute_type,
            use_fp8_w8a8=use_fp8_w8a8,
            use_int8_w8a8=use_int8_w8a8,
            use_int8_w8a16=use_int8_w8a16,
            per_channel_quant=per_channel_quant,
            HAS_BIAS=HAS_BIAS,
            BLOCK_SIZE_K=BLOCK_SIZE_K,
            **config,
        )


@triton.jit
def compute_identity_kernel(
    top_k: int,
    hidden_states_ptr: tl.tensor,
    expert_scales_ptr: tl.tensor,
    num_tokens: int,
    output_ptr: tl.tensor,
    hidden_dim: int,
    scales_stride: int,
    BLOCK_SIZE: tl.constexpr,
) -> None:
    """output[t, :] = hidden_states[t, :] * sum(expert_scales[t, 0:top_k]).

    One program handles one token and one BLOCK_SIZE-wide chunk of the
    hidden dimension; the launch grid must be
    ``num_tokens * cdiv(hidden_dim, BLOCK_SIZE)``. Rows are addressed as
    ``token * hidden_dim``, so hidden_states and output must be
    row-contiguous.
    """
    pid = tl.program_id(0)

    # Ceiling division so that a trailing partial chunk (hidden_dim not a
    # multiple of BLOCK_SIZE) gets its own program; it is masked below.
    num_dim_blocks = tl.cdiv(hidden_dim, BLOCK_SIZE)
    batch_id = pid // num_dim_blocks
    dim_offset = pid % num_dim_blocks * BLOCK_SIZE

    if batch_id >= num_tokens or dim_offset >= hidden_dim:
        return

    dim_mask = (dim_offset + tl.arange(0, BLOCK_SIZE)) < hidden_dim
    h = tl.load(hidden_states_ptr + batch_id * hidden_dim + dim_offset +
                tl.arange(0, BLOCK_SIZE),
                mask=dim_mask)

    result = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
    for i in range(top_k):
        scale = tl.load(expert_scales_ptr + batch_id * scales_stride + i)
        result += h * scale

    tl.store(output_ptr + batch_id * hidden_dim + dim_offset +
             tl.arange(0, BLOCK_SIZE),
             result,
             mask=dim_mask)


def zero_experts_compute_triton(expert_indices: torch.Tensor,
                                expert_scales: torch.Tensor, num_experts: int,
                                zero_expert_type: str,
                                hidden_states: torch.Tensor) -> torch.Tensor:
    """Compute the output of "zero" experts and mask them out of routing.

    Routing slots whose expert index is ``>= num_experts`` are treated as
    zero experts. For ``zero_expert_type == "identity"`` every such slot
    contributes ``hidden_states[t] * expert_scales[t, slot]`` to the result.

    Side effect: ``expert_indices`` and ``expert_scales`` are modified in
    place. Zero-expert slots are rewritten to index 0 with scale 0.0, which
    presumably makes the regular expert path give them no weight.

    Args:
        expert_indices: Routing indices; the last dim is top_k.
        expert_scales: Routing weights, same layout as ``expert_indices``.
        num_experts: Number of real experts; larger indices are zero experts.
        zero_expert_type: Zero-expert flavour; only "identity" is supported.
        hidden_states: Token activations with tokens on dim 0 and the hidden
            dim last. The kernel assumes rows are contiguous.

    Returns:
        A tensor like ``hidden_states`` with the summed zero-expert output.

    Raises:
        ValueError: If ``zero_expert_type`` is not supported. This is
            checked before any input is modified.
    """
    if zero_expert_type != "identity":
        raise ValueError(
            f"Unsupported zero_expert_type: {zero_expert_type}")

    top_k = expert_indices.size(-1)

    # Keep only the scales of the zero-expert slots (computed before the
    # in-place rewrite below).
    zero_expert_scales = expert_scales.clone()
    zero_expert_scales[expert_indices < num_experts] = 0.0

    # Neutralize zero-expert slots for the regular expert computation.
    zero_slot_mask = expert_indices >= num_experts
    expert_indices[zero_slot_mask] = 0
    expert_scales[zero_slot_mask] = 0.0

    output = torch.zeros_like(hidden_states)
    hidden_dim = hidden_states.size(-1)
    num_tokens = hidden_states.size(0)

    # Must match the cdiv-based program mapping in compute_identity_kernel so
    # the tail of hidden_dim is covered when it is not a multiple of
    # BLOCK_SIZE.
    grid = lambda meta: (num_tokens * triton.cdiv(hidden_dim, meta[
        'BLOCK_SIZE']), )
    compute_identity_kernel[grid](
        top_k,
        hidden_states,
        zero_expert_scales,
        num_tokens,
        output,
        hidden_dim,
        zero_expert_scales.stride(0),
        BLOCK_SIZE=256,
    )

    return output


# Adapted from: https://github.com/sgl-project/sglang/pull/2628
def get_config_file_name(E: int,
                         N: int,
                         dtype: Optional[str],
                         block_shape: Optional[list[int]] = None) -> str:
    """Build the tuned-config JSON file name for the current device.

    The name encodes E, N, the device name (spaces replaced by "_"), and
    optionally the dtype and the block shape. The block shape is included
    only when it is given and all of its entries are non-zero.
    """
    device_name = current_platform.get_device_name().replace(" ", "_")
    dtype_selector = "" if not dtype else f",dtype={dtype}"
    block_shape_selector = ("" if not block_shape or not all(block_shape) else
                            f",block_shape={block_shape}").replace(" ", "")
    return f"E={E},N={N},device_name={device_name}{dtype_selector}{block_shape_selector}.json"  # noqa: E501


# Adapted from: https://github.com/sgl-project/sglang/pull/2628
@functools.lru_cache
def get_moe_configs(
    E: int,
    N: int,
    dtype: Optional[str],
    block_n: Optional[int] = None,
    block_k: Optional[int] = None,
) -> Optional[dict[int, Any]]:
    """
    Return optimized configurations for the fused MoE kernel.

    The return value will be a dictionary that maps an irregular grid of
    batch sizes to configurations of the fused_moe kernel. To evaluate the
    kernel on a given batch size bs, the closest batch size in the grid should
    be picked and the associated configuration chosen to invoke the kernel.

    Files in ``envs.VLLM_TUNED_CONFIG_FOLDER`` take precedence over the
    bundled ``configs`` directory. Returns None if no file exists.
    """

    # First look up if an optimized configuration is available in the configs
    # directory
    block_shape = [block_n, block_k] if block_n and block_k else None
    json_file_name = get_config_file_name(E, N, dtype, block_shape)

    config_file_paths = []

    # note that we prioritize user defined config
    user_defined_config_folder = envs.VLLM_TUNED_CONFIG_FOLDER
    if user_defined_config_folder is not None:
        user_defined_config_file_path = os.path.join(
            user_defined_config_folder, json_file_name)
        config_file_paths.append(user_defined_config_file_path)

    default_config_file_path = os.path.join(
        os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name)
    config_file_paths.append(default_config_file_path)

    for config_file_path in config_file_paths:
        if os.path.exists(config_file_path):
            with open(config_file_path) as f:
                logger.info("Using configuration from %s for MoE layer.",
                            config_file_path)
                # If a configuration has been found, return it
                tuned_config = json.load(f)
                # Delete triton_version from tuned_config
                tuned_config.pop("triton_version", None)
                return {int(key): val for key, val in tuned_config.items()}

    # If no optimized configuration is available, we will use the default
    # configuration
    logger.warning(
        ("Using default MoE config. Performance might be sub-optimal! "
         "Config file not found at %s"), config_file_paths)
    return None


def get_moe_wna16_block_config(config: dict[str, int],
                               use_moe_wna16_cuda: bool,
                               num_valid_tokens: int, size_k: int, size_n: int,
                               num_experts: int, group_size: int,
                               real_top_k: int, block_size_m: int):
    """Choose BLOCK_SIZE_N / BLOCK_SIZE_K for the weight-only MoE kernels.

    Returns an empty dict when ``config`` already sets both keys. Otherwise
    it returns fixed sizes for the Triton kernel, or a heuristic choice for
    the CUDA ``moe_wna16_gemm`` kernel based on the estimated block count.
    """
    if "BLOCK_SIZE_N" in config and "BLOCK_SIZE_K" in config:
        # optimal block config is set
        return {}
    if not use_moe_wna16_cuda:
        # triton moe wna16 kernel
        if num_valid_tokens // real_top_k == 1:
            # if bs=1, use a smaller BLOCK_SIZE_N
            return {"BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 64}
        else:
            return {"BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32}
    else:
        # cuda moe wna16 kernel
        # set default block_size 128, and increase them when num_blocks
        # is too large.
        block_size_n = 128
        block_size_k = 128
        if block_size_k <= group_size:
            block_size_k = group_size

        # NOTE(review): num_n_blocks is derived from size_k and num_k_blocks
        # from size_n, both divided by block_size_k. This looks swapped /
        # mis-scaled, but the thresholds below were presumably tuned against
        # it, so it is left unchanged.
        num_n_blocks = size_k // block_size_k
        num_k_blocks = size_n // block_size_k
        num_m_blocks = (num_valid_tokens + block_size_m - 1) / block_size_m + \
            num_experts
        if num_valid_tokens // real_top_k <= block_size_m:
            num_m_blocks = min(num_m_blocks, num_valid_tokens)
        num_blocks = num_m_blocks * num_n_blocks * num_k_blocks

        if size_k % 256 == 0 and num_blocks >= 256 and \
                block_size_k < 256:
            block_size_k = 256
            # NOTE(review): block_size_k was already set to 256, so this
            # divides by 1 and leaves num_blocks unchanged. Kept as-is to
            # avoid shifting the tuned thresholds below.
            num_blocks = num_blocks // (256 // block_size_k)

        if num_m_blocks <= 16 and size_k % (block_size_k * 2) == 0 and \
                block_size_k <= 512 and num_blocks >= 512:
            block_size_k = block_size_k * 2
            num_blocks = num_blocks // 2

        if num_blocks > 1024:
            block_size_n = 256
            num_blocks = num_blocks // 2

        if size_n <= 1024 and num_blocks >= 1024:
            # The kernel performance got much better with BLOCK_SIZE_N=1024
            # when num_blocks is large, even when N is small.
            # Not sure why, maybe it force the CUDA SM process only one block
            # at the same time.
            block_size_n = 1024

        return {"BLOCK_SIZE_N": block_size_n, "BLOCK_SIZE_K": block_size_k}


def should_moe_wna16_use_cuda(num_valid_tokens: int, group_size: int,
                              num_experts: int, bit: int):
    """Return True when the CUDA ``moe_wna16_gemm`` kernel should be used.

    That is: a CUDA platform, 4-bit weights, group size 32/64/128, and at
    most 6 valid tokens per expert on average.
    """
    return current_platform.is_cuda() and bit == 4 and \
        group_size in [32, 64, 128] and num_valid_tokens / num_experts <= 6


def get_default_config(
    M: int,
    E: int,
    N: int,
    K: int,
    topk: int,
    dtype: Optional[str],
    block_shape: Optional[list[int]] = None,
) -> dict[str, int]:
    """Return a heuristic launch config when no tuned config file exists.

    For the wna16 (int4/int8 weight-only) kernels only BLOCK_SIZE_M (and
    GROUP_SIZE_M for the Triton path) is set here. BLOCK_SIZE_N/K are filled
    in later by ``get_moe_wna16_block_config``.
    """
    if dtype == "fp8_w8a8" and block_shape is not None:
        # Block-wise quant: BLOCK_SIZE_N must be divisible by block_shape[0]
        # BLOCK_SIZE_K must be divisible by block_shape[1]
        # num_stages=3 can cause triton.runtime.errors.OutOfResources
        # on ROCm, set it to 2 instead.
        config = {
            "BLOCK_SIZE_M": 64,
            "BLOCK_SIZE_N": block_shape[0],
            "BLOCK_SIZE_K": block_shape[1],
            "GROUP_SIZE_M": 32,
            "num_warps": 4,
            "num_stages": 3 if not current_platform.is_rocm() else 2,
        }
    elif dtype in ["int4_w4a16", "int8_w8a16"] and block_shape is not None:
        # moe wna16 kernels
        # only set BLOCK_SIZE_M
        # BLOCK_SIZE_N and BLOCK_SIZE_K would be set later
        bit = 4 if dtype == "int4_w4a16" else 8
        use_moe_wna16_cuda = should_moe_wna16_use_cuda(M * topk,
                                                       block_shape[1], E, bit)
        if use_moe_wna16_cuda:
            config = {"BLOCK_SIZE_M": min(16, M)}
        elif M <= 20:
            config = {"BLOCK_SIZE_M": 16, "GROUP_SIZE_M": 1}
        elif M <= 40:
            config = {"BLOCK_SIZE_M": 32, "GROUP_SIZE_M": 1}
        else:
            config = {"BLOCK_SIZE_M": 64, "GROUP_SIZE_M": 1}
    elif M <= E:
        config = {
            "BLOCK_SIZE_M": 16,
            "BLOCK_SIZE_N": 32,
            "BLOCK_SIZE_K": 64,
            "GROUP_SIZE_M": 1,
        }
    else:
        config = {
            "BLOCK_SIZE_M": 64,
            "BLOCK_SIZE_N": 64,
            "BLOCK_SIZE_K": 32,
            "GROUP_SIZE_M": 8,
        }
    return config


def try_get_optimal_moe_config(
    w1_shape: tuple[int, ...],
    w2_shape: tuple[int, ...],
    top_k: int,
    dtype: Optional[str],
    M: int,
    block_shape: Optional[list[int]] = None,
) -> dict[str, int]:
    """Return the launch config for batch size ``M``.

    Priority: the global override from ``fused_moe.get_config()``, then the
    tuned JSON entry whose batch size is closest to ``M``, then
    ``get_default_config``.
    """
    from vllm.model_executor.layers.fused_moe import get_config
    override_config = get_config()
    if override_config:
        config = override_config
    else:
        # First try to load optimal config from the file
        E, _, N = w2_shape
        if dtype == "int4_w4a16":
            # int4 weights pack two values per element of this dimension.
            N = N * 2
        block_n = block_shape[0] if block_shape else 0
        block_k = block_shape[1] if block_shape else 0
        configs = get_moe_configs(E, N, dtype, block_n, block_k)

        if configs:
            # If an optimal configuration map has been found, look up the
            # optimal config
            config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
        else:
            # Else use the default config
            config = get_default_config(M, E, N, w1_shape[2], top_k, dtype,
                                        block_shape)
    return config


def vllm_topk_softmax(topk_weights: torch.Tensor, topk_indices: torch.Tensor,
                      token_expert_indices: torch.Tensor,
                      gating_output: torch.Tensor,
                      renormalize: bool) -> tuple[torch.Tensor, ...]:
    """Softmax + top-k routing using the ``ops.topk_softmax`` custom op.

    The op is expected to fill the preallocated output buffers in place
    (``topk_indices`` is returned without reassignment). When
    ``renormalize`` is set, the selected weights are rescaled to sum to 1
    per token.
    """
    ops.topk_softmax(
        topk_weights,
        topk_indices,
        token_expert_indices,
        gating_output,
    )
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)

    return topk_weights, topk_indices


def dispatch_topk_func() -> Callable[..., tuple[torch.Tensor, ...]]:
    """Return the ROCm AITER top-k softmax if enabled, else the vLLM op."""
    if is_rocm_aiter_moe_enabled():
        from .rocm_aiter_fused_moe import rocm_aiter_topk_softmax
        return rocm_aiter_topk_softmax
    return vllm_topk_softmax


def fused_topk(
    hidden_states: torch.Tensor,
    gating_output: torch.Tensor,
    topk: int,
    renormalize: bool,
    indices_type: Optional[torch.dtype] = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Softmax top-k routing.

    Returns ``(topk_weights, topk_ids, token_expert_indices)``, each of shape
    (M, topk). The weights are float32. ``topk_ids`` is int32 unless
    ``indices_type`` is given.
    """
    assert hidden_states.size(0) == gating_output.size(0), (
        "Number of tokens mismatch")

    M, _ = hidden_states.size()

    topk_weights = torch.empty(M,
                               topk,
                               dtype=torch.float32,
                               device=hidden_states.device)
    topk_ids = torch.empty(
        M,
        topk,
        dtype=torch.int32 if indices_type is None else indices_type,
        device=hidden_states.device)
    token_expert_indices = torch.empty(M,
                                       topk,
                                       dtype=torch.int32,
                                       device=hidden_states.device)

    gating_output_float = gating_output.float()  # TODO(woosuk): Optimize this.

    topk_func = dispatch_topk_func()
    topk_weights, topk_ids = topk_func(topk_weights, topk_ids,
                                       token_expert_indices,
                                       gating_output_float, renormalize)

    return topk_weights, topk_ids, token_expert_indices


def fused_topk_bias(
    hidden_states: torch.Tensor,
    gating_output: torch.Tensor,
    e_score_correction_bias: torch.Tensor,
    topk: int,
    renormalize: bool,
):
    """Softmax top-k routing with a score-correction bias.

    The bias only affects which experts are selected. The returned weights
    are the unbiased softmax scores of the selected experts (optionally
    renormalized). Returns (float32 weights, int32 ids).
    """
    n_routed_experts = gating_output.shape[-1]
    scores = gating_output.softmax(dim=-1)
    scores_for_choice = scores.view(
        -1, n_routed_experts) + e_score_correction_bias.unsqueeze(0)
    topk_indices = torch.topk(scores_for_choice, k=topk, dim=-1,
                              sorted=False)[1]
    topk_weights = scores.gather(1, topk_indices)
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights.to(torch.float32), topk_indices.to(torch.int32)


# This is used by the Deepseek-V2 and Deepseek-V3 model
@torch.compile(dynamic=True, backend=current_platform.simple_compile_backend)
def grouped_topk(
    hidden_states: torch.Tensor,
    gating_output: torch.Tensor,
    topk: int,
    renormalize: bool,
    num_expert_group: int = 0,
    topk_group: int = 0,
    scoring_func: str = "softmax",
    routed_scaling_factor: float = 1.0,
    e_score_correction_bias: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Group-limited top-k routing.

    The experts are split into ``num_expert_group`` equal groups. For each
    token, the ``topk_group`` best groups are kept and the top-``topk``
    experts are then chosen among them. A group's score is the sum of its
    top-2 (biased) scores when ``e_score_correction_bias`` is given,
    otherwise its max score. With a bias, selection uses biased scores but
    the weights come from the unbiased scores.

    Delegates to ``fused_grouped_topk`` when the env flag is enabled on CUDA,
    both ``num_expert_group`` and ``topk`` are <= 32, and a bias is given.

    Returns:
        (float32 weights, int32 expert ids), each of shape (num_tokens, topk).

    Raises:
        ValueError: If ``scoring_func`` is not "softmax" or "sigmoid".
    """
    if envs.VLLM_USE_FUSED_MOE_GROUPED_TOPK and \
            current_platform.is_cuda() and \
            num_expert_group <= 32 and topk <= 32 and \
            e_score_correction_bias is not None:
        return fused_grouped_topk(
            hidden_states=hidden_states,
            gating_output=gating_output,
            topk=topk,
            renormalize=renormalize,
            e_score_correction_bias=e_score_correction_bias,
            num_expert_group=num_expert_group,
            topk_group=topk_group,
            scoring_func=scoring_func,
            routed_scaling_factor=routed_scaling_factor)

    assert hidden_states.size(0) == gating_output.size(0), (
        "Number of tokens mismatch")

    if scoring_func == "softmax":
        scores = torch.softmax(gating_output, dim=-1)
    elif scoring_func == "sigmoid":
        scores = gating_output.sigmoid()
    else:
        raise ValueError(f"Unsupported scoring function: {scoring_func}")

    num_token = scores.size(0)
    if e_score_correction_bias is not None:
        # Store original scores before applying correction bias. We use biased
        # scores for expert selection but original scores for routing weights
        original_scores = scores
        scores = scores + e_score_correction_bias.unsqueeze(0)
        group_scores = (scores.view(num_token, num_expert_group,
                                    -1).topk(2, dim=-1)[0].sum(dim=-1))
    else:
        group_scores = scores.view(num_token, num_expert_group,
                                   -1).max(dim=-1).values  # [n, n_group]
    group_idx = torch.topk(group_scores, k=topk_group, dim=-1,
                           sorted=False)[1]  # [n, top_k_group]
    group_mask = torch.zeros_like(group_scores)  # [n, n_group]
    group_mask.scatter_(1, group_idx, 1)  # [n, n_group]
    score_mask = group_mask.unsqueeze(-1).expand(
        num_token, num_expert_group,
        scores.size(-1) // num_expert_group).reshape(num_token, -1)  # [n, e]
    # Experts outside the selected groups can never be picked.
    tmp_scores = scores.masked_fill(~score_mask.bool(),
                                    float("-inf"))  # [n, e]

    if e_score_correction_bias is not None:
        topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)[1]
        # Use original unbiased scores for the routing weights
        topk_weights = original_scores.gather(1, topk_ids)
    else:
        topk_weights, topk_ids = torch.topk(tmp_scores,
                                            k=topk,
                                            dim=-1,
                                            sorted=False)

    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)

    if routed_scaling_factor != 1.0:
        topk_weights = topk_weights * routed_scaling_factor
    return topk_weights.to(torch.float32), topk_ids.to(torch.int32)


@torch.compile(dynamic=True, backend=current_platform.simple_compile_backend)
def eplb_map_to_physical_and_record(
        topk_ids: torch.Tensor,
        expert_load_view: torch.Tensor,
        logical_to_physical_map: torch.Tensor,
        logical_replica_count: torch.Tensor,
        indices_type: Optional[torch.dtype] = None) -> torch.Tensor:
    '''
    Map the logical expert ids to physical expert ids
    and record the expert load metrics.

    This will select a pseudo-random replica for each logical expert.
    Only used for EPLB.

    Args:
        topk_ids: The logical expert ids.
        expert_load_view: The expert load view.
        logical_to_physical_map: The logical to physical map.
        logical_replica_count: The logical replica count.
        indices_type: The indices type.

    Returns:
        The physical expert ids.
    '''

    # 1. Convert the logical expert ids to physical expert ids
    # Directly select a random replica for each logical expert

    # In case `indices_type` is not `torch.long` or `torch.int`,
    # e.g. `torch.uint32` as required by dispatch/combine kernels
    topk_ids_long = topk_ids.long()
    # Use (token position) modulo (replica count)
    # to deterministically choose a replica
    replica_count = logical_replica_count[topk_ids_long]
    # Flatten-position based index, reshaped back to `topk_ids` shape
    pos_indices = torch.arange(topk_ids.numel(),
                               device=topk_ids.device,
                               dtype=torch.long).reshape_as(topk_ids)
    # Compute pseudo-random indices by modulo
    replica_indices = (pos_indices % replica_count).unsqueeze(-1)
    physical_ids = logical_to_physical_map[topk_ids_long].gather(
        -1, replica_indices).squeeze(-1)

    topk_ids = physical_ids

    # 2. Record expert load metrics.

    # TODO(bowen): When using `FusedMoEModularKernel`, this
    # can be done in a more unified way, since
    # `FusedMoEPrepareAndFinalize` will return the expert
    # token count, in some cases directly from the kernel.
    # However, now there are many code paths not using
    # the modular kernel, e.g. calling `fused_experts`,
    # so we decide to keep the logic here.
    #
    # If later refactor moved all the MoE kernel calls
    # to the modular kernel, we can move this logic there
    # to achieve better efficiency.

    # `expert_load_view`: (num_physical_experts,)

    # `torch.bincount` is not compilable, so use `scatter_add_` instead.
topk_ids_flatten = topk_ids.flatten() expert_load_view.scatter_add_( dim=0, index=topk_ids_flatten.long(), src=torch.ones_like(topk_ids_flatten).to(expert_load_view)) if indices_type is not None: topk_ids = topk_ids.to(dtype=indices_type) return topk_ids def fused_grouped_topk( hidden_states: torch.Tensor, gating_output: torch.Tensor, topk: int, renormalize: bool, e_score_correction_bias: torch.Tensor, num_expert_group: int = 0, topk_group: int = 0, scoring_func: str = "softmax", routed_scaling_factor: float = 1.0, ) -> tuple[torch.Tensor, torch.Tensor]: assert hidden_states.size(0) == gating_output.size(0), ( "Number of tokens mismatch") if scoring_func == "softmax": scores = torch.softmax(gating_output, dim=-1) elif scoring_func == "sigmoid": scores = gating_output.sigmoid() else: raise ValueError(f"Unsupported scoring function: {scoring_func}") scores_with_bias = scores + e_score_correction_bias.unsqueeze(0) topk_values, topk_indices = ops.grouped_topk( scores, scores_with_bias.to(scores.dtype), num_expert_group, topk_group, topk, renormalize, routed_scaling_factor) return topk_values.to(torch.float32), topk_indices.to(torch.int32) def inplace_fused_experts( hidden_states: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, activation: str = "silu", apply_router_weight_on_input: bool = False, use_fp8_w8a8: bool = False, use_int8_w8a8: bool = False, use_int8_w8a16: bool = False, use_int4_w4a16: bool = False, use_mxfp4_w4a4: bool = False, per_channel_quant: bool = False, global_num_experts: int = -1, expert_map: Optional[torch.Tensor] = None, w1_scale: Optional[torch.Tensor] = None, w2_scale: Optional[torch.Tensor] = None, w1_zp: Optional[torch.Tensor] = None, w2_zp: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None, block_shape: Optional[List[int]] = None, #noqa: UP006 w1_bias: Optional[torch.Tensor] = None, w2_bias: Optional[torch.Tensor] = None, ) -> 
None: fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, True, activation, apply_router_weight_on_input, use_fp8_w8a8, use_int8_w8a8, use_int8_w8a16, use_int4_w4a16, use_mxfp4_w4a4, per_channel_quant, global_num_experts, expert_map, w1_scale, w2_scale, w1_zp, w2_zp, a1_scale, a2_scale, block_shape, w1_bias, w2_bias) def inplace_fused_experts_fake( hidden_states: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, activation: str = "silu", apply_router_weight_on_input: bool = False, use_fp8_w8a8: bool = False, use_int8_w8a8: bool = False, use_int8_w8a16: bool = False, use_int4_w4a16: bool = False, use_mxfp4_w4a4: bool = False, per_channel_quant: bool = False, global_num_experts: int = -1, expert_map: Optional[torch.Tensor] = None, w1_scale: Optional[torch.Tensor] = None, w2_scale: Optional[torch.Tensor] = None, w1_zp: Optional[torch.Tensor] = None, w2_zp: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None, block_shape: Optional[List[int]] = None, #noqa: UP006 w1_bias: Optional[torch.Tensor] = None, w2_bias: Optional[torch.Tensor] = None, ) -> None: pass direct_register_custom_op( op_name="inplace_fused_experts", op_func=inplace_fused_experts, mutates_args=["hidden_states"], fake_impl=inplace_fused_experts_fake, tags=(() if is_torch_equal_or_newer("2.7.0") else (torch.Tag.needs_fixed_stride_order, )), ) def outplace_fused_experts( hidden_states: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, activation: str = "silu", apply_router_weight_on_input: bool = False, use_fp8_w8a8: bool = False, use_int8_w8a8: bool = False, use_int8_w8a16: bool = False, use_int4_w4a16: bool = False, use_mxfp4_w4a4: bool = False, per_channel_quant: bool = False, global_num_experts: int = -1, expert_map: Optional[torch.Tensor] = None, w1_scale: Optional[torch.Tensor] = None, w2_scale: Optional[torch.Tensor] = None, 
w1_zp: Optional[torch.Tensor] = None, w2_zp: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None, block_shape: Optional[List[int]] = None, #noqa: UP006 w1_bias: Optional[torch.Tensor] = None, w2_bias: Optional[torch.Tensor] = None, ) -> torch.Tensor: return fused_experts_impl( hidden_states, w1, w2, topk_weights, topk_ids, False, activation, apply_router_weight_on_input, use_fp8_w8a8, use_int8_w8a8, use_int8_w8a16, use_int4_w4a16, use_mxfp4_w4a4, per_channel_quant, global_num_experts, expert_map, w1_scale, w2_scale, w1_zp, w2_zp, a1_scale, a2_scale, block_shape, w1_bias, w2_bias) def outplace_fused_experts_fake( hidden_states: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, activation: str = "silu", use_fp8_w8a8: bool = False, use_int8_w8a8: bool = False, use_int8_w8a16: bool = False, use_int4_w4a16: bool = False, use_mxfp4_w4a4: bool = False, per_channel_quant: bool = False, global_num_experts: int = -1, expert_map: Optional[torch.Tensor] = None, w1_scale: Optional[torch.Tensor] = None, w2_scale: Optional[torch.Tensor] = None, w1_zp: Optional[torch.Tensor] = None, w2_zp: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None, block_shape: Optional[list[int]] = None, w1_bias: Optional[torch.Tensor] = None, w2_bias: Optional[torch.Tensor] = None, ) -> torch.Tensor: return torch.empty_like(hidden_states) direct_register_custom_op( op_name="outplace_fused_experts", op_func=outplace_fused_experts, fake_impl=outplace_fused_experts_fake, tags=(() if is_torch_equal_or_newer("2.7.0") else (torch.Tag.needs_fixed_stride_order, )), ) def torch_vllm_inplace_fused_experts(**kwargs) -> torch.Tensor: torch.ops.vllm.inplace_fused_experts(**kwargs) hidden_states = kwargs['hidden_states'] return hidden_states def torch_vllm_outplace_fused_experts(**kwargs) -> torch.Tensor: return 
torch.ops.vllm.outplace_fused_experts(**kwargs) def dispatch_fused_experts_func(inplace: bool) -> Callable[..., torch.Tensor]: if inplace: return torch_vllm_inplace_fused_experts return torch_vllm_outplace_fused_experts # TODO (bnell): replace this with modular op. Can get rid of inplace/outplace # torch ops. def fused_experts( hidden_states: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, inplace: bool = False, activation: str = "silu", apply_router_weight_on_input: bool = False, global_num_experts: int = -1, expert_map: Optional[torch.Tensor] = None, quant_config: Optional[FusedMoEQuantConfig] = None, allow_deep_gemm: bool = False, allow_cutlass_block_scaled_grouped_gemm: bool = False, ) -> torch.Tensor: if quant_config is None: quant_config = FUSED_MOE_UNQUANTIZED_CONFIG use_fp8_w8a8 = quant_config.use_fp8_w8a8 # For now, disable DeepGemm for small N (<= 512) until better # permute/unpermute ops are available. # However, on B200, we use DeepGemm for all cases because they only support # E8M0 scale, which means we requantize the weight and input to the specific # scale. Fallen back to cutlass or triton for some cases would cause # accuracy issue. 
if (allow_deep_gemm and quant_config.use_fp8_w8a8 and (is_deep_gemm_e8m0_used() or _valid_deep_gemm(hidden_states, w1, w2))): assert quant_config is not None assert apply_router_weight_on_input is False return deep_gemm_moe_fp8( hidden_states=hidden_states, w1=w1, w2=w2, topk_weights=topk_weights, topk_ids=topk_ids, inplace=inplace, activation=activation, global_num_experts=global_num_experts, expert_map=expert_map, w1_scale=quant_config.w1_scale, w2_scale=quant_config.w2_scale, a1_scale=quant_config.a1_scale, a2_scale=quant_config.a2_scale, apply_router_weight_on_input=apply_router_weight_on_input, ) elif (allow_cutlass_block_scaled_grouped_gemm and use_fp8_w8a8 and _valid_cutlass_block_scaled_grouped_gemm( w1, w2, inplace, activation, apply_router_weight_on_input, expert_map)): assert quant_config is not None return run_cutlass_block_scaled_fused_experts( a=hidden_states, w1=w1, w2=w2, w1_scale=quant_config.w1_scale, w2_scale=quant_config.w2_scale, topk_weights=topk_weights, topk_ids=topk_ids) else: return dispatch_fused_experts_func(inplace)( hidden_states=hidden_states, w1=w1, w2=w2, topk_weights=topk_weights, topk_ids=topk_ids, activation=activation, apply_router_weight_on_input=apply_router_weight_on_input, use_fp8_w8a8=quant_config.use_fp8_w8a8, use_int8_w8a8=quant_config.use_int8_w8a8, use_int8_w8a16=quant_config.use_int8_w8a16, use_int4_w4a16=quant_config.use_int4_w4a16, use_mxfp4_w4a4=quant_config.use_mxfp4_w4a4, per_channel_quant=quant_config.per_act_token_quant, global_num_experts=global_num_experts, expert_map=expert_map, w1_scale=quant_config.w1_scale, w2_scale=quant_config.w2_scale, w1_zp=quant_config.w1_zp, w2_zp=quant_config.w2_zp, a1_scale=quant_config.a1_scale, a2_scale=quant_config.a2_scale, block_shape=quant_config.block_shape, w1_bias=quant_config.w1_bias, w2_bias=quant_config.w2_bias) SILU_NO_MUL: str = activation_without_mul("silu") GELU_NO_MUL: str = activation_without_mul("gelu") def _get_config_quant_dtype( use_fp8_w8a8: bool, 
use_int8_w8a8: bool, use_mxfp4_w4a4: bool, ) -> Union[None, torch.dtype, str]: """ Get the quantization type based on the quantization strategy flags. We don't have a quant_config at this point so we need to work backwards. A return type of None means no quantization is required because the input is unquantized or has been quantized prior to calling fused_experts_impl. """ if use_fp8_w8a8: return torch.float8_e4m3fn elif use_int8_w8a8: return torch.int8 elif use_mxfp4_w4a4: return "mxfp4" return None def fused_experts_impl( hidden_states: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, inplace: bool = False, activation: str = "silu", apply_router_weight_on_input: bool = False, use_fp8_w8a8: bool = False, use_int8_w8a8: bool = False, use_int8_w8a16: bool = False, use_int4_w4a16: bool = False, use_mxfp4_w4a4: bool = False, per_channel_quant: bool = False, global_num_experts: int = -1, expert_map: Optional[torch.Tensor] = None, w1_scale: Optional[torch.Tensor] = None, w2_scale: Optional[torch.Tensor] = None, w1_zp: Optional[torch.Tensor] = None, w2_zp: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None, block_shape: Optional[list[int]] = None, w1_bias: Optional[torch.Tensor] = None, w2_bias: Optional[torch.Tensor] = None, ) -> torch.Tensor: # Check constraints. 
if use_int4_w4a16: assert hidden_states.size(1) // 2 == w1.size(2), ( "Hidden size mismatch") elif use_mxfp4_w4a4: # 16bit activation and fp4x2 packed weight assert hidden_states.size(1) // 2 == w1.size(2), "hidden size mismatch" else: assert hidden_states.size(1) == w1.size(2), ( f"Hidden size mismatch {hidden_states.size(1)} != {w1.size(2)}") assert topk_weights.size() == topk_ids.size(), "topk shape mismatch" assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" assert w1.stride(-1) == 1, "Stride of last dimension must be 1" assert w2.stride(-1) == 1, "Stride of last dimension must be 1" assert hidden_states.dtype in [ torch.float32, torch.float16, torch.bfloat16 ] num_tokens = hidden_states.size(0) E, N, _ = w1.size() K = w2.size(1) if global_num_experts == -1: global_num_experts = E top_k_num = topk_ids.size(1) # We execute the fused_moe kernel in chunks to circumvent this issue: # https://github.com/vllm-project/vllm/issues/5938 CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE M = min(num_tokens, CHUNK_SIZE) config_dtype = _get_config_dtype_str(use_fp8_w8a8=use_fp8_w8a8, use_int8_w8a16=use_int8_w8a16, use_int4_w4a16=use_int4_w4a16, use_mxfp4_w4a4=use_mxfp4_w4a4, dtype=hidden_states.dtype) # Note: for use_int8_w8a16 or use_int4_w4a16, the activations are # quantized prior to calling fused_experts. 
quant_dtype = _get_config_quant_dtype(use_fp8_w8a8=use_fp8_w8a8, use_int8_w8a8=use_int8_w8a8, use_mxfp4_w4a4=use_mxfp4_w4a4) get_config_func = functools.partial( try_get_optimal_moe_config, w1.size(), w2.size(), top_k_num, config_dtype, block_shape=block_shape, ) config = get_config_func(M) # We can reuse the memory between these because by the time we need # cache3, we're done with cache1 cache13 = torch.empty(M * top_k_num * max(N, K), device=hidden_states.device, dtype=hidden_states.dtype) intermediate_cache1 = cache13[:M * top_k_num * N].view(M, top_k_num, N) intermediate_cache3 = cache13[:M * top_k_num * K].view(M, top_k_num, K) # This needs separate memory since it's used concurrently with cache1 intermediate_cache2 = torch.empty((M * top_k_num, N // 2), device=hidden_states.device, dtype=hidden_states.dtype) if hidden_states.dtype == torch.bfloat16: compute_type = tl.bfloat16 elif hidden_states.dtype == torch.float16: compute_type = tl.float16 elif hidden_states.dtype == torch.float32: compute_type = tl.float32 else: raise ValueError(f"Unsupported compute_type: {hidden_states.dtype}") if inplace: out_hidden_states = hidden_states else: out_hidden_states = torch.empty_like(hidden_states) if use_mxfp4_w4a4: # Weight has to be dequantized for mxfp4 emulation. w1 = dequant_mxfp4(w1, w1_scale, hidden_states.dtype) w1_scale = None w2 = dequant_mxfp4(w2, w2_scale, hidden_states.dtype) w2_scale = None for chunk in range((num_tokens // CHUNK_SIZE) + 1): begin_chunk_idx, end_chunk_idx = (chunk * CHUNK_SIZE, min((chunk + 1) * CHUNK_SIZE, num_tokens)) curr_hidden_states = hidden_states[begin_chunk_idx:end_chunk_idx] tokens_in_chunk, _ = curr_hidden_states.size() if tokens_in_chunk == 0: break if tokens_in_chunk < CHUNK_SIZE and chunk > 0: # Adjust the intermediate cache size and config for the last # chunk. Note that in most cases we only have one chunk # so the cache size and config are already set correctly and # do not need to be adjusted. 
intermediate_cache1 = intermediate_cache1[:tokens_in_chunk] intermediate_cache2 = intermediate_cache2[:tokens_in_chunk * topk_ids.size(1)] intermediate_cache3 = intermediate_cache3[:tokens_in_chunk] config = get_config_func(tokens_in_chunk) curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx] curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx] qcurr_hidden_states, a1q_scale = moe_kernel_quantize_input( A=curr_hidden_states, A_scale=a1_scale, quant_dtype=quant_dtype, per_act_token_quant=per_channel_quant, block_shape=block_shape) sorted_token_ids, expert_ids, num_tokens_post_padded = ( moe_align_block_size(curr_topk_ids, config['BLOCK_SIZE_M'], global_num_experts, expert_map)) invoke_fused_moe_kernel(qcurr_hidden_states, w1, intermediate_cache1, a1q_scale, w1_scale, w1_zp, curr_topk_weights, sorted_token_ids, expert_ids, num_tokens_post_padded, apply_router_weight_on_input, top_k_num, config, compute_type=compute_type, use_fp8_w8a8=use_fp8_w8a8, use_int8_w8a8=use_int8_w8a8, use_int8_w8a16=use_int8_w8a16, use_int4_w4a16=use_int4_w4a16, per_channel_quant=per_channel_quant, block_shape=block_shape, B_bias=w1_bias) # Activation function with multiplication if activation == "silu": torch.ops._C.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N)) elif activation == "gelu": torch.ops._C.gelu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N)) elif activation == "swigluoai": # alpha = 1.702, limit = 7.0 torch.ops._C.swigluoai_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N)) # Activation function without multiplication elif activation == SILU_NO_MUL: intermediate_cache2 = F.silu(intermediate_cache1.view(-1, N)) elif activation == GELU_NO_MUL: intermediate_cache2 = F.gelu(intermediate_cache1.view(-1, N)) else: raise ValueError(f"Unsupported FusedMoe activation: {activation}.") qintermediate_cache2, a2q_scale = moe_kernel_quantize_input( A=intermediate_cache2, A_scale=a2_scale, quant_dtype=quant_dtype, 
per_act_token_quant=per_channel_quant, block_shape=block_shape) invoke_fused_moe_kernel(qintermediate_cache2, w2, intermediate_cache3, a2q_scale, w2_scale, w2_zp, curr_topk_weights, sorted_token_ids, expert_ids, num_tokens_post_padded, not apply_router_weight_on_input, 1, config, compute_type=compute_type, use_fp8_w8a8=use_fp8_w8a8, use_int8_w8a8=use_int8_w8a8, use_int8_w8a16=use_int8_w8a16, use_int4_w4a16=use_int4_w4a16, per_channel_quant=per_channel_quant, block_shape=block_shape, B_bias=w2_bias) ops.moe_sum(intermediate_cache3.view(*intermediate_cache3.size()), out_hidden_states[begin_chunk_idx:end_chunk_idx]) return out_hidden_states class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute): def __init__( self, quant_config: FusedMoEQuantConfig, ): super().__init__(quant_config) @property def activation_formats( self ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]: return (mk.FusedMoEActivationFormat.Standard, mk.FusedMoEActivationFormat.Standard) def supports_chunking(self) -> bool: return True def supports_expert_map(self) -> bool: return True def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce: return TopKWeightAndReduceNoOP() def workspace_shapes( self, a: torch.Tensor, aq: torch.Tensor, M: int, N: int, K: int, topk: int, global_num_experts: int, local_num_experts: int, expert_tokens_meta: Optional[mk.ExpertTokensMetadata], ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]: workspace1 = (M, topk, max(N // 2, K)) workspace2 = (M, topk, max(N, K)) output = (M, K) return (workspace1, workspace2, output, a.dtype) def apply( self, output: torch.Tensor, hidden_states: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, activation: str, global_num_experts: int, expert_map: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor], a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor, workspace2: torch.Tensor, expert_tokens_meta: 
Optional[mk.ExpertTokensMetadata], apply_router_weight_on_input: bool, ): # Check constraints. if self.quant_config.use_int4_w4a16: assert hidden_states.size(-1) // 2 == w1.size(2), ( "Hidden size mismatch") else: assert hidden_states.size(-1) == w1.size(2), \ (f"Hidden size mismatch {hidden_states.size(-1)} " f"!= {w1.size(2)}") assert hidden_states.is_contiguous( ), "Hidden_states must be contiguous" assert hidden_states.dim() == 2 assert w1.stride(-1) == 1, "Stride of last dimension must be 1" assert w2.stride(-1) == 1, "Stride of last dimension must be 1" assert hidden_states.dtype in [ torch.float32, torch.float16, torch.bfloat16, torch.float8_e4m3fn ] E, num_tokens, N, K, top_k_num = mk._moe_problem_size( hidden_states, w1, w2, topk_ids) if global_num_experts == -1: global_num_experts = E config = try_get_optimal_moe_config( w1.size(), w2.size(), top_k_num, self.quant_config.config_name(hidden_states.dtype), num_tokens, block_shape=self.block_shape, ) if hidden_states.dtype == torch.bfloat16: compute_type = tl.bfloat16 elif hidden_states.dtype == torch.float16: compute_type = tl.float16 elif hidden_states.dtype == torch.float32: compute_type = tl.float32 elif hidden_states.dtype == torch.float8_e4m3fn: compute_type = tl.bfloat16 else: raise ValueError( f"Unsupported compute_type: {hidden_states.dtype}") # Note that the output tensor might be in workspace1 intermediate_cache1 = _resize_cache(workspace2, (num_tokens, top_k_num, N)) intermediate_cache2 = _resize_cache(workspace13, (num_tokens * top_k_num, N // 2)) intermediate_cache3 = _resize_cache(workspace2, (num_tokens, top_k_num, K)) sorted_token_ids, expert_ids, num_tokens_post_padded = ( moe_align_block_size(topk_ids, config['BLOCK_SIZE_M'], global_num_experts, expert_map)) invoke_fused_moe_kernel( hidden_states, w1, intermediate_cache1, a1q_scale, self.w1_scale, self.w1_zp, None, # topk_weights sorted_token_ids, expert_ids, num_tokens_post_padded, False, # mul_routed_weights top_k_num, config, 
compute_type=compute_type, use_fp8_w8a8=self.quant_config.use_fp8_w8a8, use_int8_w8a8=self.quant_config.use_int8_w8a8, use_int8_w8a16=self.quant_config.use_int8_w8a16, use_int4_w4a16=self.quant_config.use_int4_w4a16, per_channel_quant=self.per_act_token_quant, block_shape=self.block_shape, B_bias=self.w1_bias, ) self.activation(activation, intermediate_cache2, intermediate_cache1.view(-1, N)) a2q_scale: Optional[torch.Tensor] = None qintermediate_cache2, a2q_scale = moe_kernel_quantize_input( intermediate_cache2, a2_scale, self.quant_dtype, self.per_act_token_quant, self.block_shape) invoke_fused_moe_kernel( qintermediate_cache2, w2, intermediate_cache3, a2q_scale, self.w2_scale, self.w2_zp, topk_weights, sorted_token_ids, expert_ids, num_tokens_post_padded, not apply_router_weight_on_input, 1, config, compute_type=compute_type, use_fp8_w8a8=self.quant_config.use_fp8_w8a8, use_int8_w8a8=self.quant_config.use_int8_w8a8, use_int8_w8a16=self.quant_config.use_int8_w8a16, use_int4_w4a16=self.quant_config.use_int4_w4a16, per_channel_quant=self.per_act_token_quant, block_shape=self.block_shape, B_bias=self.w2_bias, ) ops.moe_sum(intermediate_cache3, output) def modular_triton_fused_moe( quant_config: FusedMoEQuantConfig) -> mk.FusedMoEModularKernel: return mk.FusedMoEModularKernel( MoEPrepareAndFinalizeNoEP(), TritonExperts(quant_config), )