# SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Fused MoE Triton kernels.""" import functools import json import os from collections.abc import Callable from typing import Any import torch import torch.nn.functional as F import vllm.envs as envs import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm import _custom_ops as ops from vllm._aiter_ops import rocm_aiter_ops from vllm.logger import init_logger from vllm.model_executor.layers.batch_invariant import ( vllm_is_batch_invariant, ) from vllm.model_executor.layers.fused_moe.config import ( FUSED_MOE_UNQUANTIZED_CONFIG, FusedMoEQuantConfig, _get_config_dtype_str, ) from vllm.model_executor.layers.fused_moe.cutlass_moe import ( _valid_cutlass_block_scaled_grouped_gemm, run_cutlass_block_scaled_fused_experts, ) from vllm.model_executor.layers.fused_moe.deep_gemm_moe import ( _valid_deep_gemm, deep_gemm_moe_fp8, ) from vllm.model_executor.layers.fused_moe.moe_align_block_size import ( moe_align_block_size, ) from vllm.model_executor.layers.fused_moe.prepare_finalize import ( MoEPrepareAndFinalizeNoEP, ) from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( TopKWeightAndReduceNoOP, ) from vllm.model_executor.layers.fused_moe.utils import ( _resize_cache, activation_without_mul, disable_inplace, moe_kernel_quantize_input, ) from vllm.model_executor.layers.quantization.utils.mxfp4_utils import dequant_mxfp4 from vllm.model_executor.layers.quantization.utils.mxfp6_utils import dequant_mxfp6 from vllm.model_executor.layers.quantization.utils.ocp_mx_utils import OCP_MX_Scheme from vllm.model_executor.utils import maybe_disable_graph_partition from vllm.platforms import current_platform from vllm.triton_utils import tl, triton from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used from vllm.utils.torch_utils import direct_register_custom_op, is_torch_equal_or_newer logger = init_logger(__name__) @triton.jit def write_zeros_to_output( c_ptr, stride_cm, stride_cn, pid_n, N, offs_token, token_mask, BLOCK_SIZE_M, BLOCK_SIZE_N, compute_type, ): accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=compute_type) offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :] c_mask = token_mask[:, None] & (offs_cn[None, :] < N) tl.store(c_ptrs, accumulator, mask=c_mask) @triton.jit def fused_moe_kernel_gptq_awq( # Pointers to matrices a_ptr, b_ptr, c_ptr, b_scale_ptr, b_zp_ptr, topk_weights_ptr, sorted_token_ids_ptr, expert_ids_ptr, num_tokens_post_padded_ptr, # Matrix dimensions N: tl.constexpr, K: tl.constexpr, EM, num_valid_tokens, # The stride variables represent how much to increase the ptr by when # moving by 1 element in a particular dimension. E.g. `stride_am` is # how much to increase `a_ptr` by to get the element one row down # (A has M rows). stride_am, stride_ak, stride_be, stride_bk, stride_bn, stride_cm, stride_cn, stride_bse, stride_bsk, stride_bsn, stride_bze, stride_bzk, stride_bzn, block_k_diviable: tl.constexpr, group_size: tl.constexpr, # Meta-parameters BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr, SPLIT_K: tl.constexpr, MUL_ROUTED_WEIGHT: tl.constexpr, top_k: tl.constexpr, compute_type: tl.constexpr, has_zp: tl.constexpr, use_int4_w4a16: tl.constexpr, use_int8_w8a16: tl.constexpr, ): """ Implements the fused computation for a Mixture of Experts (MOE) using token and expert matrices. 
Key Parameters: - A: The input tensor representing tokens with shape (*, K), where '*' can be any shape representing batches and K is the feature dimension of each token. - B: The stacked MOE weight tensor with shape (E, N, K), where E is the number of experts, K is the input feature dimension, and N is the output feature dimension. - C: The output cache tensor with shape (M, topk, N), where M is the total number of tokens post padding, topk is the number of times each token is repeated, and N is the output feature dimension. - sorted_token_ids: A tensor containing the sorted indices of tokens, repeated topk times and arranged by the expert index they are assigned to. - expert_ids: A tensor containing the indices of the expert for each block. It determines which expert matrix from B should be used for each block in A. This kernel performs the multiplication of a token by its corresponding expert matrix as determined by `expert_ids`. The sorting of `sorted_token_ids` by expert index and padding ensures divisibility by BLOCK_SIZE_M, which is necessary to maintain consistency in block matrix multiplication across different blocks processed by the same expert. """ # ----------------------------------------------------------- # Map program ids `pid` to the block of C it should compute. # This is done in a grouped ordering to promote L2 data reuse. pid = tl.program_id(axis=0) num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M) num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) num_pid_in_group = GROUP_SIZE_M * num_pid_n group_id = pid // num_pid_in_group first_pid_m = group_id * GROUP_SIZE_M group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) pid_n = (pid % num_pid_in_group) // group_size_m # ---------------------------------------------------------- # Create pointers for the first blocks of A and B. # We will advance this pointer as we move in the K direction # and accumulate # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr) if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded: return offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64) offs_token = tl.load(sorted_token_ids_ptr + offs_token_id) token_mask = offs_token < num_valid_tokens off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64) if off_experts == -1: # ----------------------------------------------------------- # Write back zeros to the output when the expert is not # in the current expert parallel rank. write_zeros_to_output( c_ptr, stride_cm, stride_cn, pid_n, N, offs_token, token_mask, BLOCK_SIZE_M, BLOCK_SIZE_N, compute_type, ) return offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N offs_k = tl.arange(0, BLOCK_SIZE_K) a_ptrs = a_ptr + ( offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak ) if use_int4_w4a16: b_ptrs = ( b_ptr + off_experts * stride_be + (offs_k[:, None] // 2) * stride_bk + offs_bn[None, :] * stride_bn ) b_shifter = (offs_k[:, None] % 2) * 4 elif use_int8_w8a16: b_ptrs = ( b_ptr + off_experts * stride_be + offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn ) if not has_zp and use_int4_w4a16: b_zp_num = 8 if not has_zp and use_int8_w8a16: b_zp_num = 128 elif has_zp and use_int4_w4a16: b_zp_shifter = (offs_bn[None, :] % 2) * 4 # ----------------------------------------------------------- # Iterate to compute a block of the C matrix. 
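    # Each K iteration first dequantizes the current B tile in registers:
    # int4 nibbles are unpacked with a shift and mask, the zero-point (an
    # explicit tensor when has_zp, otherwise the fixed midpoint 8 for int4
    # or 128 for int8) is subtracted, and the result is multiplied by the
    # per-group scale, i.e. roughly w = (w_q - zp) * scale[k // group_size].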
# We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block # of fp32 values for higher accuracy. # `accumulator` will be converted back to fp16 after the loop. accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): # Load the next block of A and B, generate a mask by checking the # K dimension. if not block_k_diviable: k_mask = offs_k[:, None] < K - k * BLOCK_SIZE_K k_other = 0.0 else: k_mask = None k_other = None a = tl.load( a_ptrs, mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K), other=0.0, ) b = tl.load(b_ptrs) if use_int4_w4a16: b = (b >> b_shifter) & 0xF b_scale_ptrs = ( b_scale_ptr + off_experts * stride_bse + offs_bn[None, :] * stride_bsn + ((offs_k[:, None] + BLOCK_SIZE_K * k) // group_size) * stride_bsk ) b_scale = tl.load(b_scale_ptrs, mask=k_mask, other=k_other) b_scale = b_scale.to(tl.float32) if has_zp and use_int4_w4a16: offs_k_true = (offs_k[:, None] + BLOCK_SIZE_K * k) // group_size b_zp_ptrs = ( b_zp_ptr + off_experts * stride_bze + (offs_bn[None, :] // 2) * stride_bzn + offs_k_true * stride_bzk ) b_zp = tl.load(b_zp_ptrs, mask=k_mask, other=k_other) b_zp = (b_zp >> b_zp_shifter) & 0xF b_zp = b_zp.to(tl.float32) elif has_zp and use_int8_w8a16: offs_k_true = (offs_k[:, None] + BLOCK_SIZE_K * k) // group_size b_zp_ptrs = ( b_zp_ptr + off_experts * stride_bze + offs_bn[None, :] * stride_bzn + offs_k_true * stride_bzk ) b_zp = tl.load(b_zp_ptrs, mask=k_mask, other=k_other) b_zp = b_zp.to(tl.float32) # We accumulate along the K dimension. if has_zp: b = ((b.to(tl.float32) - b_zp) * b_scale).to(compute_type) else: b = ((b.to(tl.float32) - b_zp_num) * b_scale).to(compute_type) accumulator = tl.dot(a, b, acc=accumulator) # Advance the ptrs to the next K block. a_ptrs += BLOCK_SIZE_K * stride_ak if use_int4_w4a16: b_ptrs += (BLOCK_SIZE_K // 2) * stride_bk else: b_ptrs += BLOCK_SIZE_K * stride_bk if MUL_ROUTED_WEIGHT: moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0) accumulator = accumulator * moe_weight[:, None] accumulator = accumulator.to(compute_type) # ----------------------------------------------------------- # Write back the block of the output offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :] c_mask = token_mask[:, None] & (offs_cn[None, :] < N) tl.store(c_ptrs, accumulator, mask=c_mask) @triton.jit def fused_moe_kernel( # Pointers to matrices a_ptr, b_ptr, c_ptr, b_bias_ptr, a_scale_ptr, b_scale_ptr, topk_weights_ptr, sorted_token_ids_ptr, expert_ids_ptr, num_tokens_post_padded_ptr, # Matrix dimensions N, K, EM, num_valid_tokens, # The stride variables represent how much to increase the ptr by when # moving by 1 element in a particular dimension. E.g. `stride_am` is # how much to increase `a_ptr` by to get the element one row down # (A has M rows). 
stride_am, stride_ak, stride_be, stride_bk, stride_bn, stride_cm, stride_cn, stride_asm, stride_ask, stride_bse, stride_bsk, stride_bsn, stride_bbe, # bias expert stride stride_bbn, # bias N stride # Block size for block-wise quantization group_n: tl.constexpr, group_k: tl.constexpr, # Meta-parameters BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr, SPLIT_K: tl.constexpr, MUL_ROUTED_WEIGHT: tl.constexpr, top_k: tl.constexpr, compute_type: tl.constexpr, use_fp8_w8a8: tl.constexpr, use_int8_w8a8: tl.constexpr, use_int8_w8a16: tl.constexpr, per_channel_quant: tl.constexpr, HAS_BIAS: tl.constexpr, ): """ Implements the fused computation for a Mixture of Experts (MOE) using token and expert matrices. Key Parameters: - A: The input tensor representing tokens with shape (*, K), where '*' can be any shape representing batches and K is the feature dimension of each token. - B: The stacked MOE weight tensor with shape (E, N, K), where E is the number of experts, K is the input feature dimension, and N is the output feature dimension. - C: The output cache tensor with shape (M, topk, N), where M is the total number of tokens post padding, topk is the number of times each token is repeated, and N is the output feature dimension. - sorted_token_ids: A tensor containing the sorted indices of tokens, repeated topk times and arranged by the expert index they are assigned to. - expert_ids: A tensor containing the indices of the expert for each block. It determines which expert matrix from B should be used for each block in A. This kernel performs the multiplication of a token by its corresponding expert matrix as determined by `expert_ids`. The sorting of `sorted_token_ids` by expert index and padding ensures divisibility by BLOCK_SIZE_M, which is necessary to maintain consistency in block matrix multiplication across different blocks processed by the same expert. """ # ----------------------------------------------------------- # Map program ids `pid` to the block of C it should compute. # This is done in a grouped ordering to promote L2 data reuse. pid = tl.program_id(axis=0) num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M) num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) num_pid_in_group = GROUP_SIZE_M * num_pid_n group_id = pid // num_pid_in_group first_pid_m = group_id * GROUP_SIZE_M group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) pid_n = (pid % num_pid_in_group) // group_size_m # ---------------------------------------------------------- # Create pointers for the first blocks of A and B. # We will advance this pointer as we move in the K direction # and accumulate # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr) if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded: return offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64) offs_token = tl.load(sorted_token_ids_ptr + offs_token_id) token_mask = offs_token < num_valid_tokens off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64) if off_experts == -1: # ----------------------------------------------------------- # Write back zeros to the output when the expert is not # in the current expert parallel rank. 
write_zeros_to_output( c_ptr, stride_cm, stride_cn, pid_n, N, offs_token, token_mask, BLOCK_SIZE_M, BLOCK_SIZE_N, compute_type, ) return offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N offs_k = tl.arange(0, BLOCK_SIZE_K) a_ptrs = a_ptr + ( offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak ) b_ptrs = ( b_ptr + off_experts * stride_be + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) ) if use_int8_w8a16: b_scale_ptrs = ( b_scale_ptr + off_experts * stride_bse + offs_bn[None, :] * stride_bsn ) b_scale = tl.load(b_scale_ptrs) if use_fp8_w8a8 or use_int8_w8a8: # block-wise if group_k > 0 and group_n > 0: a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm offs_bsn = offs_bn // group_n b_scale_ptrs = ( b_scale_ptr + off_experts * stride_bse + offs_bsn * stride_bsn ) # channel-wise elif per_channel_quant: b_scale_ptrs = ( b_scale_ptr + off_experts * stride_bse + offs_bn[None, :] * stride_bsn ) b_scale = tl.load(b_scale_ptrs) # Load per-token scale for activations a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm a_scale = tl.load(a_scale_ptrs, mask=token_mask, other=0.0)[:, None] # tensor-wise else: a_scale = tl.load(a_scale_ptr) b_scale = tl.load(b_scale_ptr + off_experts) if HAS_BIAS: # bias shape: [num_experts, N] bias_ptrs = b_bias_ptr + off_experts * stride_bbe + offs_bn * stride_bbn bias = tl.load(bias_ptrs, mask=(offs_bn < N), other=0.0) # ----------------------------------------------------------- # Iterate to compute a block of the C matrix. # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block # of fp32 values for higher accuracy. # `accumulator` will be converted back to fp16 after the loop. accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): # Load the next block of A and B, generate a mask by checking the # K dimension. a = tl.load( a_ptrs, mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K), other=0.0, ) b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) # We accumulate along the K dimension. if use_int8_w8a16: accumulator = tl.dot(a, b.to(compute_type), acc=accumulator) elif use_fp8_w8a8 or use_int8_w8a8: if group_k > 0 and group_n > 0: k_start = k * BLOCK_SIZE_K offs_ks = k_start // group_k a_scale = tl.load( a_scale_ptrs + offs_ks * stride_ask, mask=token_mask, other=0.0 ) b_scale = tl.load(b_scale_ptrs + offs_ks * stride_bsk) accumulator += tl.dot(a, b) * a_scale[:, None] * b_scale[None, :] else: if use_fp8_w8a8: # acc used to enable fp8_fast_accum accumulator = tl.dot(a, b, acc=accumulator) else: accumulator += tl.dot(a, b) else: accumulator += tl.dot(a, b) # Advance the ptrs to the next K block. 
a_ptrs += BLOCK_SIZE_K * stride_ak b_ptrs += BLOCK_SIZE_K * stride_bk if HAS_BIAS: accumulator = accumulator + bias[None, :] if MUL_ROUTED_WEIGHT: moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0) accumulator = accumulator * moe_weight[:, None] if use_int8_w8a16: accumulator = (accumulator * b_scale).to(compute_type) elif use_fp8_w8a8 or use_int8_w8a8: if group_k > 0 and group_n > 0: accumulator = accumulator.to(compute_type) else: accumulator = (accumulator * a_scale * b_scale).to(compute_type) else: accumulator = accumulator.to(compute_type) # ----------------------------------------------------------- # Write back the block of the output offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :] c_mask = token_mask[:, None] & (offs_cn[None, :] < N) tl.store(c_ptrs, accumulator, mask=c_mask) def invoke_fused_moe_kernel( A: torch.Tensor, B: torch.Tensor, C: torch.Tensor, A_scale: torch.Tensor | None, B_scale: torch.Tensor | None, B_zp: torch.Tensor | None, topk_weights: torch.Tensor | None, sorted_token_ids: torch.Tensor, expert_ids: torch.Tensor, num_tokens_post_padded: torch.Tensor, mul_routed_weight: bool, top_k: int, config: dict[str, Any], compute_type: tl.dtype, use_fp8_w8a8: bool, use_int8_w8a8: bool, use_int8_w8a16: bool, use_int4_w4a16: bool, per_channel_quant: bool, block_shape: list[int] | None = None, B_bias: torch.Tensor | None = None, ) -> None: assert topk_weights is not None or not mul_routed_weight assert topk_weights is None or topk_weights.stride(1) == 1 assert sorted_token_ids.stride(0) == 1 if use_fp8_w8a8 or use_int8_w8a8: assert B_scale is not None assert block_shape is None or triton.cdiv( B.size(-2), block_shape[0] ) == B_scale.size(-2) assert block_shape is None or triton.cdiv( B.size(-1), block_shape[1] ) == B_scale.size(-1) elif use_int8_w8a16 or use_int4_w4a16: assert B_scale is not None assert block_shape is None or block_shape[0] == 0 else: assert A_scale is None assert B_scale is None M = A.size(0) num_tokens = M * top_k EM = sorted_token_ids.size(0) if A.size(0) < config["BLOCK_SIZE_M"]: # optimize for small batch_size. # We assume that top_ids of each token is unique, # so num_valid_experts <= batch_size <= BLOCK_SIZE_M, # and we can skip some invalid blocks. 
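        # As an illustration (numbers are hypothetical): with 4 tokens,
        # top_k=2 and BLOCK_SIZE_M=16, at most 8 distinct (token, expert)
        # rows exist, so EM can be capped at 4 * 2 * 16 = 128 instead of
        # the full padded length of sorted_token_ids.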
EM = min(sorted_token_ids.size(0), A.size(0) * top_k * config["BLOCK_SIZE_M"]) grid = lambda META: ( triton.cdiv(EM, META["BLOCK_SIZE_M"]) * triton.cdiv(B.size(1), META["BLOCK_SIZE_N"]), ) HAS_BIAS = B_bias is not None if ( (use_int8_w8a16 or use_int4_w4a16) and block_shape is not None and block_shape[1] > 0 ): assert B_scale is not None and B_scale.ndim == 3 assert B_zp is None or B_zp.ndim == 3 use_moe_wna16_cuda = should_moe_wna16_use_cuda( num_valid_tokens=num_tokens, group_size=block_shape[1], num_experts=B.size(0), bit=4 if use_int4_w4a16 else 8, ) config = config.copy() config.update( get_moe_wna16_block_config( config=config, use_moe_wna16_cuda=use_moe_wna16_cuda, num_valid_tokens=num_tokens, size_k=A.size(1), size_n=B.size(1), num_experts=B.size(1), group_size=block_shape[1], real_top_k=top_k, block_size_m=config["BLOCK_SIZE_M"], ) ) if use_moe_wna16_cuda: bit = 4 if use_int4_w4a16 else 8 ops.moe_wna16_gemm( A, C, B, B_scale, B_zp, topk_weights if mul_routed_weight else None, sorted_token_ids, expert_ids, num_tokens_post_padded, top_k, config["BLOCK_SIZE_M"], config["BLOCK_SIZE_N"], config["BLOCK_SIZE_K"], bit, ) return fused_moe_kernel_gptq_awq[grid]( A, B, C, B_scale, B_zp, topk_weights, sorted_token_ids, expert_ids, num_tokens_post_padded, B.size(1), A.size(1), EM, num_tokens, A.stride(0), A.stride(1), B.stride(0), B.stride(2), B.stride(1), C.stride(1), C.stride(2), B_scale.stride(0), B_scale.stride(2), B_scale.stride(1), B_zp.stride(0) if B_zp is not None else 0, B_zp.stride(2) if B_zp is not None else 0, B_zp.stride(1) if B_zp is not None else 0, block_k_diviable=A.size(1) % config["BLOCK_SIZE_K"] == 0, group_size=block_shape[1], MUL_ROUTED_WEIGHT=mul_routed_weight, top_k=top_k, compute_type=compute_type, has_zp=B_zp is not None, use_int4_w4a16=use_int4_w4a16, use_int8_w8a16=use_int8_w8a16, **config, ) else: config = config.copy() config["SPLIT_K"] = 1 BLOCK_SIZE_K = config.pop("BLOCK_SIZE_K") if block_shape is not None: BLOCK_SIZE_K = min(BLOCK_SIZE_K, min(block_shape[0], block_shape[1])) fused_moe_kernel[grid]( A, B, C, B_bias, A_scale, B_scale, topk_weights, sorted_token_ids, expert_ids, num_tokens_post_padded, B.size(1), B.size(2), EM, num_tokens, A.stride(0), A.stride(1), B.stride(0), B.stride(2), B.stride(1), C.stride(1), C.stride(2), A_scale.stride(0) if A_scale is not None and A_scale.ndim == 2 else 0, A_scale.stride(1) if A_scale is not None and A_scale.ndim == 2 else 0, B_scale.stride(0) if B_scale is not None and B_scale.ndim >= 2 else 0, B_scale.stride(2) if B_scale is not None and B_scale.ndim == 3 else 0, B_scale.stride(1) if B_scale is not None and B_scale.ndim >= 2 else 0, B_bias.stride(0) if B_bias is not None else 0, B_bias.stride(1) if B_bias is not None else 0, 0 if block_shape is None else block_shape[0], 0 if block_shape is None else block_shape[1], MUL_ROUTED_WEIGHT=mul_routed_weight, top_k=top_k, compute_type=compute_type, use_fp8_w8a8=use_fp8_w8a8, use_int8_w8a8=use_int8_w8a8, use_int8_w8a16=use_int8_w8a16, per_channel_quant=per_channel_quant, HAS_BIAS=HAS_BIAS, BLOCK_SIZE_K=BLOCK_SIZE_K, **config, ) @triton.jit def compute_identity_kernel( top_k: int, hidden_states_ptr: tl.tensor, expert_scales_ptr: tl.tensor, num_tokens: int, output_ptr: tl.tensor, hidden_dim: int, scales_stride: int, BLOCK_SIZE: tl.constexpr, ) -> None: pid = tl.program_id(0) batch_id = pid // (hidden_dim // BLOCK_SIZE) dim_offset = pid % (hidden_dim // BLOCK_SIZE) * BLOCK_SIZE if batch_id >= num_tokens or dim_offset >= hidden_dim: return h = tl.load( hidden_states_ptr + batch_id * 
hidden_dim + dim_offset + tl.arange(0, BLOCK_SIZE), mask=(dim_offset + tl.arange(0, BLOCK_SIZE)) < hidden_dim, ) result = tl.zeros([BLOCK_SIZE], dtype=tl.float32) for i in range(top_k): scale = tl.load(expert_scales_ptr + batch_id * scales_stride + i) result += h * scale tl.store( output_ptr + batch_id * hidden_dim + dim_offset + tl.arange(0, BLOCK_SIZE), result, mask=(dim_offset + tl.arange(0, BLOCK_SIZE)) < hidden_dim, ) def zero_experts_compute_triton( expert_indices: torch.Tensor, expert_scales: torch.Tensor, num_experts: int, zero_expert_type: str, hidden_states: torch.Tensor, ) -> torch.Tensor: N = expert_indices.numel() top_k = expert_indices.size(-1) grid = lambda meta: (triton.cdiv(N, meta["BLOCK_SIZE"]),) if zero_expert_type == "identity": zero_expert_mask = expert_indices < num_experts zero_expert_scales = expert_scales.clone() zero_expert_scales[zero_expert_mask] = 0.0 normal_expert_mask = expert_indices >= num_experts expert_indices[normal_expert_mask] = 0 expert_scales[normal_expert_mask] = 0.0 output = torch.zeros_like(hidden_states).to(hidden_states.device) hidden_dim = hidden_states.size(-1) num_tokens = hidden_states.size(0) grid = lambda meta: (num_tokens * (hidden_dim // meta["BLOCK_SIZE"]),) compute_identity_kernel[grid]( top_k, hidden_states, zero_expert_scales, num_tokens, output, hidden_dim, zero_expert_scales.stride(0), BLOCK_SIZE=256, ) return output # Adapted from: https://github.com/sgl-project/sglang/pull/2628 def get_config_file_name( E: int, N: int, dtype: str | None, block_shape: list[int] | None = None ) -> str: device_name = current_platform.get_device_name().replace(" ", "_") # Set device_name to H200 if a device from the H200 family is detected if "H200" in device_name.split("_"): device_name = "NVIDIA_H200" dtype_selector = "" if not dtype else f",dtype={dtype}" block_shape_selector = ( "" if not block_shape or not all(block_shape) else f",block_shape={block_shape}" ).replace(" ", "") return f"E={E},N={N},device_name={device_name}{dtype_selector}{block_shape_selector}.json" # noqa: E501 # Adapted from: https://github.com/sgl-project/sglang/pull/2628 @functools.lru_cache def get_moe_configs( E: int, N: int, dtype: str | None, block_n: int | None = None, block_k: int | None = None, ) -> dict[int, Any] | None: """ Return optimized configurations for the fused MoE kernel. The return value will be a dictionary that maps an irregular grid of batch sizes to configurations of the fused_moe kernel. To evaluate the kernel on a given batch size bs, the closest batch size in the grid should be picked and the associated configuration chosen to invoke the kernel. """ # Avoid optimizing for the batch invariant case. 
Use default config if vllm_is_batch_invariant(): return None # First look up if an optimized configuration is available in the configs # directory block_shape = [block_n, block_k] if block_n and block_k else None json_file_name = get_config_file_name(E, N, dtype, block_shape) config_file_paths = [] # note that we prioritize user defined config user_defined_config_folder = envs.VLLM_TUNED_CONFIG_FOLDER if user_defined_config_folder is not None: user_defined_config_file_path = os.path.join( user_defined_config_folder, json_file_name ) config_file_paths.append(user_defined_config_file_path) default_config_file_path = os.path.join( os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name ) config_file_paths.append(default_config_file_path) for config_file_path in config_file_paths: if os.path.exists(config_file_path): with open(config_file_path) as f: logger.info_once( "Using configuration from %s for MoE layer.", config_file_path, scope="global", ) # If a configuration has been found, return it tuned_config = json.load(f) # Delete triton_version from tuned_config tuned_config.pop("triton_version", None) return {int(key): val for key, val in tuned_config.items()} # If no optimized configuration is available, we will use the default # configuration logger.warning_once( "Using default MoE config. Performance might be sub-optimal! " "Config file not found at %s", ", ".join(config_file_paths), scope="local", ) return None def _ensure_block_size_k_divisible( size_k: int, block_size_k: int, group_size: int ) -> int: """Ensure block_size_k is a divisor of size_k and divisible by group_size. This ensures BLOCK_SIZE_K compatibility with MoeWNA16 CUDA kernel which requires size_k % BLOCK_SIZE_K == 0 and BLOCK_SIZE_K % group_size == 0. Args: size_k: The size_k dimension that must be divisible by result. block_size_k: Preferred block size (will be adjusted if needed). group_size: The result must be divisible by this. Returns: A valid BLOCK_SIZE_K that divides size_k and is divisible by group_size. """ # Fast path: already valid if size_k % block_size_k == 0 and block_size_k % group_size == 0: return block_size_k # Find the largest value that: # 1. Divides size_k (size_k % candidate == 0) # 2. Is divisible by group_size (candidate % group_size == 0) # 3. 
Is <= block_size_k (prefer smaller values close to block_size_k)
    #
    # Strategy: search from min(block_size_k, size_k) down to group_size,
    # stepping by group_size to ensure divisibility by group_size.
    max_search = min(block_size_k, size_k)
    start = (max_search // group_size) * group_size
    for candidate in range(start, group_size - 1, -group_size):
        if size_k % candidate == 0:
            return candidate

    # Fallback: if group_size divides size_k, use it.
    # This should always be true with a correct group_size configuration.
    if size_k % group_size == 0:
        return group_size

    # This should not happen with a correct group_size, but ensure divisibility
    return size_k


def get_moe_wna16_block_config(
    config: dict[str, int],
    use_moe_wna16_cuda: bool,
    num_valid_tokens: int,
    size_k: int,
    size_n: int,
    num_experts: int,
    group_size: int,
    real_top_k: int,
    block_size_m: int,
):
    if "BLOCK_SIZE_N" in config and "BLOCK_SIZE_K" in config:
        # optimal block config is set
        return {}
    if not use_moe_wna16_cuda:
        # triton moe wna16 kernel
        if num_valid_tokens // real_top_k == 1:
            # if bs=1, use a smaller BLOCK_SIZE_N
            return {"BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 64}
        else:
            return {"BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32}
    else:
        # cuda moe wna16 kernel
        # set default block sizes to 128, and increase them when num_blocks
        # is too large.
        block_size_n = 128
        block_size_k = 128
        if block_size_k <= group_size:
            block_size_k = group_size

        num_n_blocks = size_n // block_size_n
        num_k_blocks = size_k // block_size_k
        num_m_blocks = (
            num_valid_tokens + block_size_m - 1
        ) / block_size_m + num_experts
        if num_valid_tokens // real_top_k <= block_size_m:
            num_m_blocks = min(num_m_blocks, num_valid_tokens)
        num_blocks = num_m_blocks * num_n_blocks * num_k_blocks

        if size_k % 256 == 0 and num_blocks >= 256 and block_size_k < 256:
            block_size_k = 256
            num_blocks = num_blocks // (256 // block_size_k)

        if (
            num_m_blocks <= 16
            and size_k % (block_size_k * 2) == 0
            and block_size_k <= 512
            and num_blocks >= 512
        ):
            block_size_k = block_size_k * 2
            num_blocks = num_blocks // 2

        if num_blocks > 1024:
            block_size_n = 256
            num_n_blocks = num_n_blocks // 2
            num_blocks = num_blocks // 2

        if size_n <= 1024 and num_blocks >= 1024:
            # The kernel performance got much better with BLOCK_SIZE_N=1024
            # when num_blocks is large, even when N is small.
            # Not sure why; maybe it forces each CUDA SM to process only one
            # block at a time.
            block_size_n = 1024

        # Ensure BLOCK_SIZE_K is a divisor of size_k for CUDA kernel compatibility
        block_size_k = _ensure_block_size_k_divisible(size_k, block_size_k, group_size)

        return {"BLOCK_SIZE_N": block_size_n, "BLOCK_SIZE_K": block_size_k}


def should_moe_wna16_use_cuda(
    num_valid_tokens: int, group_size: int, num_experts: int, bit: int
):
    return (
        current_platform.is_cuda()
        and bit == 4
        and group_size in [32, 64, 128]
        and num_valid_tokens / num_experts <= 6
    )


def get_default_config(
    M: int,
    E: int,
    N: int,
    K: int,
    topk: int,
    dtype: str | None,
    block_shape: list[int] | None = None,
) -> dict[str, int]:
    if vllm_is_batch_invariant():
        config = {
            "BLOCK_SIZE_M": 64,
            "BLOCK_SIZE_N": 64,
            "BLOCK_SIZE_K": 32,
            "GROUP_SIZE_M": 8,
            "SPLIT_K": 1,
        }
        return config
    if dtype == "fp8_w8a8" and block_shape is not None:
        # Block-wise quant: BLOCK_SIZE_N must be divisible by block_shape[0];
        # BLOCK_SIZE_K must be divisible by block_shape[1].
        # num_stages=3 can cause triton.runtime.errors.OutOfResources
        # on ROCm, so set it to 2 instead.
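        # For instance, with a (hypothetical) block_shape of [128, 128],
        # this branch yields BLOCK_SIZE_N=128 and BLOCK_SIZE_K=128, so each
        # Triton tile lines up with exactly one quantization block.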
config = { "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": block_shape[0], "BLOCK_SIZE_K": block_shape[1], "GROUP_SIZE_M": 32, "SPLIT_K": 1, "num_warps": 4, "num_stages": 3 if not current_platform.is_rocm() else 2, } elif dtype in ["int4_w4a16", "int8_w8a16"] and block_shape is not None: # moe wna16 kernels # only set BLOCK_SIZE_M # BLOCK_SIZE_N and BLOCK_SIZE_K would be set later bit = 4 if dtype == "int4_w4a16" else 8 use_moe_wna16_cuda = should_moe_wna16_use_cuda(M * topk, block_shape[1], E, bit) if use_moe_wna16_cuda: config = {"BLOCK_SIZE_M": min(16, M), "SPLIT_K": 1} elif M <= 20: config = {"BLOCK_SIZE_M": 16, "GROUP_SIZE_M": 1, "SPLIT_K": 1} elif M <= 40: config = {"BLOCK_SIZE_M": 32, "GROUP_SIZE_M": 1, "SPLIT_K": 1} else: config = {"BLOCK_SIZE_M": 64, "GROUP_SIZE_M": 1, "SPLIT_K": 1} elif M <= E: config = { "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, "SPLIT_K": 1, } else: config = { "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8, "SPLIT_K": 1, } return config def try_get_optimal_moe_config( w1_shape: tuple[int, ...], w2_shape: tuple[int, ...], top_k: int, dtype: str | None, M: int, block_shape: list[int] | None = None, ) -> dict[str, int]: from vllm.model_executor.layers.fused_moe import get_config override_config = get_config() if override_config: config = override_config else: # First try to load optimal config from the file E, _, N = w2_shape if dtype == "int4_w4a16": N = N * 2 block_n = block_shape[0] if block_shape else 0 block_k = block_shape[1] if block_shape else 0 configs = get_moe_configs(E, N, dtype, block_n, block_k) if configs: # If an optimal configuration map has been found, look up the # optimal config config = configs[min(configs.keys(), key=lambda x: abs(x - M))] else: # Else use the default config config = get_default_config(M, E, N, w1_shape[2], top_k, dtype, block_shape) return config def vllm_topk_softmax( topk_weights: torch.Tensor, topk_indices: torch.Tensor, token_expert_indices: torch.Tensor, gating_output: torch.Tensor, renormalize: bool, ) -> tuple[torch.Tensor, ...]: ops.topk_softmax( topk_weights, topk_indices, token_expert_indices, gating_output, renormalize, ) return topk_weights, topk_indices def dispatch_topk_func( use_rocm_aiter: bool = False, ) -> Callable[..., tuple[torch.Tensor, ...]]: if use_rocm_aiter: return rocm_aiter_ops.topk_softmax return vllm_topk_softmax def fused_topk( hidden_states: torch.Tensor, gating_output: torch.Tensor, topk: int, renormalize: bool, indices_type: torch.dtype | None = None, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: assert hidden_states.size(0) == gating_output.size(0), "Number of tokens mismatch" M, _ = hidden_states.size() topk_weights = torch.empty( M, topk, dtype=torch.float32, device=hidden_states.device ) topk_ids = torch.empty( M, topk, dtype=torch.int32 if indices_type is None else indices_type, device=hidden_states.device, ) token_expert_indices = torch.empty( M, topk, dtype=torch.int32, device=hidden_states.device ) topk_func = dispatch_topk_func(use_rocm_aiter=rocm_aiter_ops.is_fused_moe_enabled()) topk_weights, topk_ids = topk_func( topk_weights, topk_ids, token_expert_indices, gating_output, renormalize ) return topk_weights, topk_ids, token_expert_indices def fused_topk_bias( hidden_states: torch.Tensor, gating_output: torch.Tensor, e_score_correction_bias: torch.Tensor, topk: int, renormalize: bool, ): n_routed_experts = gating_output.shape[-1] scores = gating_output.softmax(dim=-1) scores_for_choice = scores.view( -1, 
n_routed_experts ) + e_score_correction_bias.unsqueeze(0) # For batch invariance, use sorted=True to ensure deterministic expert selection use_sorted = vllm_is_batch_invariant() topk_indices = torch.topk(scores_for_choice, k=topk, dim=-1, sorted=use_sorted)[1] topk_weights = scores.gather(1, topk_indices) if renormalize: topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) return topk_weights.to(torch.float32), topk_indices.to(torch.int32) # This is used by the Deepseek-V2 and Deepseek-V3 model @torch.compile( dynamic=True, backend=current_platform.simple_compile_backend, options=maybe_disable_graph_partition(current_platform.simple_compile_backend), ) def grouped_topk( hidden_states: torch.Tensor, gating_output: torch.Tensor, topk: int, renormalize: bool, num_expert_group: int = 0, topk_group: int = 0, scoring_func: str = "softmax", routed_scaling_factor: float = 1.0, e_score_correction_bias: torch.Tensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: if ( envs.VLLM_USE_FUSED_MOE_GROUPED_TOPK and current_platform.is_cuda() and num_expert_group <= 32 and topk <= 32 and e_score_correction_bias is not None ): return fused_grouped_topk( hidden_states=hidden_states, gating_output=gating_output, topk=topk, renormalize=renormalize, e_score_correction_bias=e_score_correction_bias, num_expert_group=num_expert_group, topk_group=topk_group, scoring_func=scoring_func, routed_scaling_factor=routed_scaling_factor, ) assert hidden_states.size(0) == gating_output.size(0), "Number of tokens mismatch" if scoring_func == "softmax": scores = torch.softmax(gating_output, dim=-1) elif scoring_func == "sigmoid": scores = gating_output.sigmoid() else: raise ValueError(f"Unsupported scoring function: {scoring_func}") num_token = scores.size(0) if e_score_correction_bias is not None: # Store original scores before applying correction bias. 
We use biased # scores for expert selection but original scores for routing weights original_scores = scores scores = scores + e_score_correction_bias.unsqueeze(0) group_scores = ( scores.view(num_token, num_expert_group, -1).topk(2, dim=-1)[0].sum(dim=-1) ) else: group_scores = ( scores.view(num_token, num_expert_group, -1).max(dim=-1).values ) # [n, n_group] # For batch invariance, use sorted=True to ensure deterministic expert selection use_sorted = vllm_is_batch_invariant() group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=use_sorted)[ 1 ] # [n, top_k_group] group_mask = torch.zeros_like(group_scores) # [n, n_group] group_mask.scatter_(1, group_idx, 1) # [n, n_group] score_mask = ( group_mask.unsqueeze(-1) .expand(num_token, num_expert_group, scores.size(-1) // num_expert_group) .reshape(num_token, -1) ) # [n, e] tmp_scores = scores.masked_fill(~score_mask.bool(), float("-inf")) # [n, e] if e_score_correction_bias is not None: topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=use_sorted)[1] # Use original unbiased scores for the routing weights topk_weights = original_scores.gather(1, topk_ids) else: topk_weights, topk_ids = torch.topk( tmp_scores, k=topk, dim=-1, sorted=use_sorted ) if renormalize: topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) if routed_scaling_factor != 1.0: topk_weights = topk_weights * routed_scaling_factor return topk_weights.to(torch.float32), topk_ids.to(torch.int32) @torch.compile(dynamic=True, backend=current_platform.simple_compile_backend) def eplb_map_to_physical_and_record( topk_ids: torch.Tensor, expert_load_view: torch.Tensor, logical_to_physical_map: torch.Tensor, logical_replica_count: torch.Tensor, ) -> torch.Tensor: """ Map the logical expert ids to physical expert ids and record the expert load metrics. This will select a pseudo-random replica for each logical expert. Only used for EPLB. Args: topk_ids: The logical expert ids. expert_load_view: The expert load view. logical_to_physical_map: The logical to physical map. logical_replica_count: The logical replica count. Returns: The physical expert ids. """ # 1. Convert the logical expert ids to physical expert ids # Directly select a random replica for each logical expert # In case `indices_type` is not `torch.long` or `torch.int`, # e.g. `torch.uint32` as required by dispatch/combine kernels topk_ids_long = topk_ids.long() # Use (token position) modulo (replica count) # to deterministically choose a replica replica_count = logical_replica_count[topk_ids_long] # Flatten-position based index, reshaped back to `topk_ids` shape pos_indices = torch.arange( topk_ids.numel(), device=topk_ids.device, dtype=torch.long ).reshape_as(topk_ids) # Compute pseudo-random indices by modulo replica_indices = (pos_indices % replica_count).unsqueeze(-1) physical_ids = ( logical_to_physical_map[topk_ids_long].gather(-1, replica_indices).squeeze(-1) ) topk_ids = physical_ids # 2. Record expert load metrics. # TODO(bowen): When using `FusedMoEModularKernel`, this # can be done in a more unified way, since # `FusedMoEPrepareAndFinalize` will return the expert # token count, in some cases directly from the kernel. # However, now there are many code paths not using # the modular kernel, e.g. calling `fused_experts`, # so we decide to keep the logic here. # # If later refactor moved all the MoE kernel calls # to the modular kernel, we can move this logic there # to achieve better efficiency. 
# `expert_load_view`: (num_physical_experts,) # `torch.bincount` is not compilable, so use `scatter_add_` instead. topk_ids_flatten = topk_ids.flatten() expert_load_view.scatter_add_( dim=0, index=topk_ids_flatten.long(), src=torch.ones_like(topk_ids_flatten).to(expert_load_view), ) return topk_ids def fused_grouped_topk( hidden_states: torch.Tensor, gating_output: torch.Tensor, topk: int, renormalize: bool, e_score_correction_bias: torch.Tensor, num_expert_group: int = 0, topk_group: int = 0, scoring_func: str = "softmax", routed_scaling_factor: float = 1.0, ) -> tuple[torch.Tensor, torch.Tensor]: assert hidden_states.size(0) == gating_output.size(0), "Number of tokens mismatch" if scoring_func == "sigmoid": # Fully fused kernel path for sigmoid topk_values, topk_indices = ops.grouped_topk( gating_output, # raw logits num_expert_group, topk_group, topk, renormalize, routed_scaling_factor, e_score_correction_bias.to(gating_output.dtype), 1, # scoring_func=1 for sigmoid ) elif scoring_func == "softmax": # Apply softmax in Python, then use fused kernel # TODO: Add support for softmax in kernel scores = torch.softmax(gating_output, dim=-1) topk_values, topk_indices = ops.grouped_topk( scores, # pre-computed scores num_expert_group, topk_group, topk, renormalize, routed_scaling_factor, e_score_correction_bias.to(gating_output.dtype), 0, # scoring_func=0 (no activation, scores already computed) ) else: raise ValueError(f"Unsupported scoring function: {scoring_func}") # Fused kernel outputs float32 values and int32 indices directly return topk_values, topk_indices def inplace_fused_experts( hidden_states: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, activation: str = "silu", apply_router_weight_on_input: bool = False, use_fp8_w8a8: bool = False, use_int8_w8a8: bool = False, use_int8_w8a16: bool = False, use_int4_w4a16: bool = False, ocp_mx_scheme: str | None = None, per_channel_quant: bool = False, global_num_experts: int = -1, expert_map: torch.Tensor | None = None, w1_scale: torch.Tensor | None = None, w2_scale: torch.Tensor | None = None, w1_zp: torch.Tensor | None = None, w2_zp: torch.Tensor | None = None, a1_scale: torch.Tensor | None = None, a2_scale: torch.Tensor | None = None, block_shape: list[int] | None = None, w1_bias: torch.Tensor | None = None, w2_bias: torch.Tensor | None = None, ) -> None: fused_experts_impl( hidden_states, w1, w2, topk_weights, topk_ids, True, activation, apply_router_weight_on_input, use_fp8_w8a8, use_int8_w8a8, use_int8_w8a16, use_int4_w4a16, ocp_mx_scheme, per_channel_quant, global_num_experts, expert_map, w1_scale, w2_scale, w1_zp, w2_zp, a1_scale, a2_scale, block_shape, w1_bias, w2_bias, ) def inplace_fused_experts_fake( hidden_states: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, activation: str = "silu", apply_router_weight_on_input: bool = False, use_fp8_w8a8: bool = False, use_int8_w8a8: bool = False, use_int8_w8a16: bool = False, use_int4_w4a16: bool = False, ocp_mx_scheme: str | None = None, per_channel_quant: bool = False, global_num_experts: int = -1, expert_map: torch.Tensor | None = None, w1_scale: torch.Tensor | None = None, w2_scale: torch.Tensor | None = None, w1_zp: torch.Tensor | None = None, w2_zp: torch.Tensor | None = None, a1_scale: torch.Tensor | None = None, a2_scale: torch.Tensor | None = None, block_shape: list[int] | None = None, w1_bias: torch.Tensor | None = None, w2_bias: torch.Tensor | None = None, ) -> None: pass 
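# The fake impl above gives torch.compile a meta function to trace: the real
# op mutates `hidden_states` in place and returns nothing, so a no-op with a
# matching signature suffices. A minimal sketch of invoking the registered op
# directly (tensor names and shapes are illustrative):
#
#   torch.ops.vllm.inplace_fused_experts(
#       hidden_states=x,            # (num_tokens, K), updated in place
#       w1=w1,                      # (E, N, K) fused gate/up projections
#       w2=w2,                      # (E, K, N // 2) down projection
#       topk_weights=topk_weights,  # (num_tokens, top_k), float32
#       topk_ids=topk_ids,          # (num_tokens, top_k), int32
#   )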
direct_register_custom_op( op_name="inplace_fused_experts", op_func=inplace_fused_experts, mutates_args=["hidden_states"], fake_impl=inplace_fused_experts_fake, tags=( () if is_torch_equal_or_newer("2.7.0") else (torch.Tag.needs_fixed_stride_order,) ), ) def outplace_fused_experts( hidden_states: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, activation: str = "silu", apply_router_weight_on_input: bool = False, use_fp8_w8a8: bool = False, use_int8_w8a8: bool = False, use_int8_w8a16: bool = False, use_int4_w4a16: bool = False, ocp_mx_scheme: str | None = None, per_channel_quant: bool = False, global_num_experts: int = -1, expert_map: torch.Tensor | None = None, w1_scale: torch.Tensor | None = None, w2_scale: torch.Tensor | None = None, w1_zp: torch.Tensor | None = None, w2_zp: torch.Tensor | None = None, a1_scale: torch.Tensor | None = None, a2_scale: torch.Tensor | None = None, block_shape: list[int] | None = None, w1_bias: torch.Tensor | None = None, w2_bias: torch.Tensor | None = None, ) -> torch.Tensor: return fused_experts_impl( hidden_states, w1, w2, topk_weights, topk_ids, False, activation, apply_router_weight_on_input, use_fp8_w8a8, use_int8_w8a8, use_int8_w8a16, use_int4_w4a16, ocp_mx_scheme, per_channel_quant, global_num_experts, expert_map, w1_scale, w2_scale, w1_zp, w2_zp, a1_scale, a2_scale, block_shape, w1_bias, w2_bias, ) def outplace_fused_experts_fake( hidden_states: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, activation: str = "silu", use_fp8_w8a8: bool = False, use_int8_w8a8: bool = False, use_int8_w8a16: bool = False, use_int4_w4a16: bool = False, ocp_mx_scheme: str | None = None, per_channel_quant: bool = False, global_num_experts: int = -1, expert_map: torch.Tensor | None = None, w1_scale: torch.Tensor | None = None, w2_scale: torch.Tensor | None = None, w1_zp: torch.Tensor | None = None, w2_zp: torch.Tensor | None = None, a1_scale: torch.Tensor | None = None, a2_scale: torch.Tensor | None = None, block_shape: list[int] | None = None, w1_bias: torch.Tensor | None = None, w2_bias: torch.Tensor | None = None, ) -> torch.Tensor: return torch.empty_like(hidden_states) direct_register_custom_op( op_name="outplace_fused_experts", op_func=outplace_fused_experts, fake_impl=outplace_fused_experts_fake, tags=( () if is_torch_equal_or_newer("2.7.0") else (torch.Tag.needs_fixed_stride_order,) ), ) def torch_vllm_inplace_fused_experts(**kwargs) -> torch.Tensor: torch.ops.vllm.inplace_fused_experts(**kwargs) hidden_states = kwargs["hidden_states"] return hidden_states def torch_vllm_outplace_fused_experts(**kwargs) -> torch.Tensor: return torch.ops.vllm.outplace_fused_experts(**kwargs) def dispatch_fused_experts_func(inplace: bool) -> Callable[..., torch.Tensor]: if inplace and not disable_inplace(): return torch_vllm_inplace_fused_experts return torch_vllm_outplace_fused_experts # TODO (bnell): replace this with modular op. Can get rid of inplace/outplace # torch ops. 
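# A minimal end-to-end sketch of this entry point (tensor names and the
# top-k value are illustrative): route with fused_topk, then dispatch.
#
#   topk_weights, topk_ids, _ = fused_topk(
#       hidden_states, router_logits, topk=2, renormalize=True
#   )
#   out = fused_experts(hidden_states, w1, w2, topk_weights, topk_ids)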
def fused_experts(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    inplace: bool = False,
    activation: str = "silu",
    apply_router_weight_on_input: bool = False,
    global_num_experts: int = -1,
    expert_map: torch.Tensor | None = None,
    quant_config: FusedMoEQuantConfig | None = None,
    allow_deep_gemm: bool = False,
    allow_cutlass_block_scaled_grouped_gemm: bool = False,
) -> torch.Tensor:
    if quant_config is None:
        quant_config = FUSED_MOE_UNQUANTIZED_CONFIG

    use_fp8_w8a8 = quant_config.use_fp8_w8a8

    # For now, disable DeepGemm for small N (<= 512) until better
    # permute/unpermute ops are available.
    # However, on B200 we use DeepGemm for all cases because it only supports
    # the E8M0 scale there, which means we requantize the weights and inputs
    # to that scale. Falling back to cutlass or triton in those cases would
    # cause accuracy issues.
    if (
        allow_deep_gemm
        and quant_config.use_fp8_w8a8
        and (is_deep_gemm_e8m0_used() or _valid_deep_gemm(hidden_states, w1, w2))
    ):
        assert quant_config is not None
        assert apply_router_weight_on_input is False
        return deep_gemm_moe_fp8(
            hidden_states=hidden_states,
            w1=w1,
            w2=w2,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            inplace=inplace,
            activation=activation,
            global_num_experts=global_num_experts,
            expert_map=expert_map,
            w1_scale=quant_config.w1_scale,
            w2_scale=quant_config.w2_scale,
            a1_scale=quant_config.a1_scale,
            a2_scale=quant_config.a2_scale,
            apply_router_weight_on_input=apply_router_weight_on_input,
        )
    elif (
        allow_cutlass_block_scaled_grouped_gemm
        and use_fp8_w8a8
        and _valid_cutlass_block_scaled_grouped_gemm(
            w1, w2, inplace, activation, apply_router_weight_on_input, expert_map
        )
    ):
        assert quant_config is not None
        return run_cutlass_block_scaled_fused_experts(
            a=hidden_states,
            w1=w1,
            w2=w2,
            w1_scale=quant_config.w1_scale,
            w2_scale=quant_config.w2_scale,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
        )
    else:
        return dispatch_fused_experts_func(inplace)(
            hidden_states=hidden_states,
            w1=w1,
            w2=w2,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            activation=activation,
            apply_router_weight_on_input=apply_router_weight_on_input,
            use_fp8_w8a8=quant_config.use_fp8_w8a8,
            use_int8_w8a8=quant_config.use_int8_w8a8,
            use_int8_w8a16=quant_config.use_int8_w8a16,
            use_int4_w4a16=quant_config.use_int4_w4a16,
            ocp_mx_scheme=quant_config.ocp_mx_scheme,
            per_channel_quant=quant_config.per_act_token_quant,
            global_num_experts=global_num_experts,
            expert_map=expert_map,
            w1_scale=quant_config.w1_scale,
            w2_scale=quant_config.w2_scale,
            w1_zp=quant_config.w1_zp,
            w2_zp=quant_config.w2_zp,
            a1_scale=quant_config.a1_scale,
            a2_scale=quant_config.a2_scale,
            block_shape=quant_config.block_shape,
            w1_bias=quant_config.w1_bias,
            w2_bias=quant_config.w2_bias,
        )


SILU_NO_MUL: str = activation_without_mul("silu")
GELU_NO_MUL: str = activation_without_mul("gelu")
RELU2_NO_MUL: str = activation_without_mul("relu2")


def _get_config_quant_dtype(
    use_fp8_w8a8: bool,
    use_int8_w8a8: bool,
    ocp_mx_scheme: str | None,
) -> None | torch.dtype | str:
    """
    Get the quantization type based on the quantization strategy flags.
    We don't have a quant_config at this point so we need to work backwards.
    A return type of None means no quantization is required because the
    input is unquantized or has been quantized prior to calling
    fused_experts_impl.
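
    For example (mirroring the branches below):
        use_fp8_w8a8=True -> torch.float8_e4m3fn
        use_int8_w8a8=True -> torch.int8
        ocp_mx_scheme="w_mxfp4_a_mxfp4" -> "mxfp4"
        ocp_mx_scheme="w_mxfp6_e2m3_a_mxfp6_e2m3" -> "mxfp6_e2m3"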
""" if use_fp8_w8a8: return torch.float8_e4m3fn elif use_int8_w8a8: return torch.int8 elif ocp_mx_scheme == "w_mxfp4_a_mxfp4": return "mxfp4" elif ocp_mx_scheme in {"w_mxfp4_a_mxfp6_e3m2", "w_mxfp6_e3m2_a_mxfp6_e3m2"}: return "mxfp6_e3m2" elif ocp_mx_scheme in {"w_mxfp4_a_mxfp6_e2m3", "w_mxfp6_e2m3_a_mxfp6_e2m3"}: return "mxfp6_e2m3" return None def fused_experts_impl( hidden_states: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, inplace: bool = False, activation: str = "silu", apply_router_weight_on_input: bool = False, use_fp8_w8a8: bool = False, use_int8_w8a8: bool = False, use_int8_w8a16: bool = False, use_int4_w4a16: bool = False, ocp_mx_scheme: str | None = None, per_channel_quant: bool = False, global_num_experts: int = -1, expert_map: torch.Tensor | None = None, w1_scale: torch.Tensor | None = None, w2_scale: torch.Tensor | None = None, w1_zp: torch.Tensor | None = None, w2_zp: torch.Tensor | None = None, a1_scale: torch.Tensor | None = None, a2_scale: torch.Tensor | None = None, block_shape: list[int] | None = None, w1_bias: torch.Tensor | None = None, w2_bias: torch.Tensor | None = None, ) -> torch.Tensor: # Check constraints. if use_int4_w4a16: assert hidden_states.size(1) // 2 == w1.size(2), "Hidden size mismatch" elif ocp_mx_scheme is not None: if ocp_mx_scheme in { "w_mxfp4_a_mxfp4", "w_mxfp4_a_mxfp6_e3m2", "w_mxfp4_a_mxfp6_e2m3", }: # 16bit activation and fp4x2 packed weight assert hidden_states.size(1) == w1.size(2) * 2, "hidden size mismatch" elif ocp_mx_scheme in { "w_mxfp6_e3m2_a_mxfp6_e3m2", "w_mxfp6_e2m3_a_mxfp6_e2m3", }: assert hidden_states.size(1) == (w1.size(2) * 4) // 3, ( "hidden size mismatch" ) else: raise NotImplementedError(f"Unsupported ocp_mx_scheme={ocp_mx_scheme}") else: assert hidden_states.size(1) == w1.size(2), ( f"Hidden size mismatch {hidden_states.size(1)} != {w1.size(2)}" ) assert topk_weights.size() == topk_ids.size(), "topk shape mismatch" assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" assert w1.stride(-1) == 1, "Stride of last dimension must be 1" assert w2.stride(-1) == 1, "Stride of last dimension must be 1" assert hidden_states.dtype in [torch.float32, torch.float16, torch.bfloat16] num_tokens = hidden_states.size(0) E, N, _ = w1.size() K = w2.size(1) if global_num_experts == -1: global_num_experts = E top_k_num = topk_ids.size(1) # We execute the fused_moe kernel in chunks to circumvent this issue: # https://github.com/vllm-project/vllm/issues/5938 CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE M = min(num_tokens, CHUNK_SIZE) config_dtype = _get_config_dtype_str( use_fp8_w8a8=use_fp8_w8a8, use_int8_w8a16=use_int8_w8a16, use_int4_w4a16=use_int4_w4a16, ocp_mx_scheme=ocp_mx_scheme, dtype=hidden_states.dtype, ) # Note: for use_int8_w8a16 or use_int4_w4a16, the activations are # quantized prior to calling fused_experts. 
quant_dtype = _get_config_quant_dtype( use_fp8_w8a8=use_fp8_w8a8, use_int8_w8a8=use_int8_w8a8, ocp_mx_scheme=ocp_mx_scheme, ) get_config_func = functools.partial( try_get_optimal_moe_config, w1.size(), w2.size(), top_k_num, config_dtype, block_shape=block_shape, ) config = get_config_func(M) # We can reuse the memory between these because by the time we need # cache3, we're done with cache1 cache13 = torch.empty( M * top_k_num * max(N, K), device=hidden_states.device, dtype=hidden_states.dtype, ) intermediate_cache1 = cache13[: M * top_k_num * N].view(M, top_k_num, N) intermediate_cache3 = cache13[: M * top_k_num * K].view(M, top_k_num, K) # This needs separate memory since it's used concurrently with cache1 intermediate_cache2 = torch.empty( (M * top_k_num, N // 2), device=hidden_states.device, dtype=hidden_states.dtype ) if hidden_states.dtype == torch.bfloat16: compute_type = tl.bfloat16 elif hidden_states.dtype == torch.float16: compute_type = tl.float16 elif hidden_states.dtype == torch.float32: compute_type = tl.float32 else: raise ValueError(f"Unsupported compute_type: {hidden_states.dtype}") if inplace and not disable_inplace(): out_hidden_states = hidden_states else: out_hidden_states = torch.empty_like(hidden_states) if ocp_mx_scheme is not None: # TODO: On platforms for which `current_platform.supports_mx()` is True # and for which we have a native OCP mx fused MOE kernel, # this dequantization step should not be done. if ocp_mx_scheme in { OCP_MX_Scheme.w_mxfp4_a_mxfp4, OCP_MX_Scheme.w_mxfp4_a_mxfp6_e3m2, OCP_MX_Scheme.w_mxfp4_a_mxfp6_e2m3, }: # Weight has to be dequantized for mxfp4 emulation. w1 = dequant_mxfp4(w1, w1_scale, hidden_states.dtype) w1_scale = None w2 = dequant_mxfp4(w2, w2_scale, hidden_states.dtype) w2_scale = None elif ocp_mx_scheme == OCP_MX_Scheme.w_mxfp6_e3m2_a_mxfp6_e3m2: w1 = dequant_mxfp6( w1, w1_scale, quant_dtype="fp6_e3m2", float_dtype=hidden_states.dtype ) w1_scale = None w2 = dequant_mxfp6( w2, w2_scale, quant_dtype="fp6_e3m2", float_dtype=hidden_states.dtype ) w2_scale = None elif ocp_mx_scheme == OCP_MX_Scheme.w_mxfp6_e2m3_a_mxfp6_e2m3: w1 = dequant_mxfp6( w1, w1_scale, quant_dtype="fp6_e2m3", float_dtype=hidden_states.dtype ) w1_scale = None w2 = dequant_mxfp6( w2, w2_scale, quant_dtype="fp6_e2m3", float_dtype=hidden_states.dtype ) w2_scale = None else: raise NotImplementedError(f"Unsupported ocp_mx_scheme={ocp_mx_scheme}") for chunk in range((num_tokens // CHUNK_SIZE) + 1): begin_chunk_idx, end_chunk_idx = ( chunk * CHUNK_SIZE, min((chunk + 1) * CHUNK_SIZE, num_tokens), ) curr_hidden_states = hidden_states[begin_chunk_idx:end_chunk_idx] tokens_in_chunk, _ = curr_hidden_states.size() if tokens_in_chunk == 0: break if tokens_in_chunk < CHUNK_SIZE and chunk > 0: # Adjust the intermediate cache size and config for the last # chunk. Note that in most cases we only have one chunk # so the cache size and config are already set correctly and # do not need to be adjusted. 
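            # Illustrative numbers: with CHUNK_SIZE=65536 and
            # num_tokens=100_000, chunk 0 runs with the caches sized above,
            # while chunk 1 covers the remaining 34_464 tokens, so the
            # caches are narrowed and the config re-fetched for the
            # smaller M below.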
intermediate_cache1 = intermediate_cache1[:tokens_in_chunk] intermediate_cache2 = intermediate_cache2[ : tokens_in_chunk * topk_ids.size(1) ] intermediate_cache3 = intermediate_cache3[:tokens_in_chunk] config = get_config_func(tokens_in_chunk) curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx] curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx] qcurr_hidden_states, a1q_scale = moe_kernel_quantize_input( A=curr_hidden_states, A_scale=a1_scale, quant_dtype=quant_dtype, per_act_token_quant=per_channel_quant, block_shape=block_shape, ) sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size( curr_topk_ids, config["BLOCK_SIZE_M"], global_num_experts, expert_map, ignore_invalid_experts=True, ) invoke_fused_moe_kernel( qcurr_hidden_states, w1, intermediate_cache1, a1q_scale, w1_scale, w1_zp, curr_topk_weights, sorted_token_ids, expert_ids, num_tokens_post_padded, apply_router_weight_on_input, top_k_num, config, compute_type=compute_type, use_fp8_w8a8=use_fp8_w8a8, use_int8_w8a8=use_int8_w8a8, use_int8_w8a16=use_int8_w8a16, use_int4_w4a16=use_int4_w4a16, per_channel_quant=per_channel_quant, block_shape=block_shape, B_bias=w1_bias, ) # Activation function with multiplication if activation == "silu": torch.ops._C.silu_and_mul( intermediate_cache2, intermediate_cache1.view(-1, N) ) elif activation == "gelu": torch.ops._C.gelu_and_mul( intermediate_cache2, intermediate_cache1.view(-1, N) ) elif activation == "swigluoai": # alpha = 1.702, limit = 7.0 torch.ops._C.swigluoai_and_mul( intermediate_cache2, intermediate_cache1.view(-1, N) ) # Activation function without multiplication elif activation == SILU_NO_MUL: intermediate_cache2 = F.silu(intermediate_cache1.view(-1, N)) elif activation == GELU_NO_MUL: intermediate_cache2 = F.gelu(intermediate_cache1.view(-1, N)) elif activation == RELU2_NO_MUL: intermediate_cache2 = torch.square(F.relu(intermediate_cache1.view(-1, N))) else: raise ValueError(f"Unsupported FusedMoe activation: {activation}.") qintermediate_cache2, a2q_scale = moe_kernel_quantize_input( A=intermediate_cache2, A_scale=a2_scale, quant_dtype=quant_dtype, per_act_token_quant=per_channel_quant, block_shape=block_shape, ) if expert_map is not None: intermediate_cache3.zero_() invoke_fused_moe_kernel( qintermediate_cache2, w2, intermediate_cache3, a2q_scale, w2_scale, w2_zp, curr_topk_weights, sorted_token_ids, expert_ids, num_tokens_post_padded, not apply_router_weight_on_input, 1, config, compute_type=compute_type, use_fp8_w8a8=use_fp8_w8a8, use_int8_w8a8=use_int8_w8a8, use_int8_w8a16=use_int8_w8a16, use_int4_w4a16=use_int4_w4a16, per_channel_quant=per_channel_quant, block_shape=block_shape, B_bias=w2_bias, ) ops.moe_sum( intermediate_cache3.view(*intermediate_cache3.size()), out_hidden_states[begin_chunk_idx:end_chunk_idx], ) return out_hidden_states class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute): def __init__( self, quant_config: FusedMoEQuantConfig, ): super().__init__(quant_config) @property def activation_formats( self, ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]: return ( mk.FusedMoEActivationFormat.Standard, mk.FusedMoEActivationFormat.Standard, ) def supports_chunking(self) -> bool: return True def supports_expert_map(self) -> bool: return True def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce: return TopKWeightAndReduceNoOP() def workspace_shapes( self, M: int, N: int, K: int, topk: int, global_num_experts: int, local_num_experts: int, expert_tokens_meta: mk.ExpertTokensMetadata | None, ) -> 
tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: workspace1 = (M, topk, max(N // 2, K)) workspace2 = (M, topk, max(N, K)) output = (M, K) return (workspace1, workspace2, output) def apply( self, output: torch.Tensor, hidden_states: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, activation: str, global_num_experts: int, expert_map: torch.Tensor | None, a1q_scale: torch.Tensor | None, a2_scale: torch.Tensor | None, workspace13: torch.Tensor, workspace2: torch.Tensor, expert_tokens_meta: mk.ExpertTokensMetadata | None, apply_router_weight_on_input: bool, ): # Check constraints. if self.quant_config.use_int4_w4a16: assert hidden_states.size(-1) // 2 == w1.size(2), "Hidden size mismatch" else: assert hidden_states.size(-1) == w1.size(2), ( f"Hidden size mismatch {hidden_states.size(-1)} != {w1.size(2)}" ) assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" assert hidden_states.dim() == 2 assert w1.stride(-1) == 1, "Stride of last dimension must be 1" assert w2.stride(-1) == 1, "Stride of last dimension must be 1" assert hidden_states.dtype in [ torch.float32, torch.float16, torch.bfloat16, torch.float8_e4m3fn, ] E, num_tokens, N, K, top_k_num = self.moe_problem_size( hidden_states, w1, w2, topk_ids ) if global_num_experts == -1: global_num_experts = E config = try_get_optimal_moe_config( w1.size(), w2.size(), top_k_num, self.quant_config.config_name(hidden_states.dtype), num_tokens, block_shape=self.block_shape, ) if hidden_states.dtype == torch.bfloat16: compute_type = tl.bfloat16 elif hidden_states.dtype == torch.float16: compute_type = tl.float16 elif hidden_states.dtype == torch.float32: compute_type = tl.float32 elif hidden_states.dtype == torch.float8_e4m3fn: compute_type = tl.bfloat16 else: raise ValueError(f"Unsupported compute_type: {hidden_states.dtype}") # Note that the output tensor might be in workspace1 intermediate_cache1 = _resize_cache(workspace2, (num_tokens, top_k_num, N)) intermediate_cache2 = _resize_cache( workspace13, (num_tokens * top_k_num, N // 2) ) intermediate_cache3 = _resize_cache(workspace2, (num_tokens, top_k_num, K)) sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size( topk_ids, config["BLOCK_SIZE_M"], global_num_experts, expert_map ) invoke_fused_moe_kernel( hidden_states, w1, intermediate_cache1, a1q_scale, self.w1_scale, self.w1_zp, None, # topk_weights sorted_token_ids, expert_ids, num_tokens_post_padded, False, # mul_routed_weights top_k_num, config, compute_type=compute_type, use_fp8_w8a8=self.quant_config.use_fp8_w8a8, use_int8_w8a8=self.quant_config.use_int8_w8a8, use_int8_w8a16=self.quant_config.use_int8_w8a16, use_int4_w4a16=self.quant_config.use_int4_w4a16, per_channel_quant=self.per_act_token_quant, block_shape=self.block_shape, B_bias=self.w1_bias, ) self.activation( activation, intermediate_cache2, intermediate_cache1.view(-1, N) ) a2q_scale: torch.Tensor | None = None qintermediate_cache2, a2q_scale = moe_kernel_quantize_input( intermediate_cache2, a2_scale, self.quant_dtype, self.per_act_token_quant, self.block_shape, ) invoke_fused_moe_kernel( qintermediate_cache2, w2, intermediate_cache3, a2q_scale, self.w2_scale, self.w2_zp, topk_weights, sorted_token_ids, expert_ids, num_tokens_post_padded, not apply_router_weight_on_input, 1, config, compute_type=compute_type, use_fp8_w8a8=self.quant_config.use_fp8_w8a8, use_int8_w8a8=self.quant_config.use_int8_w8a8, use_int8_w8a16=self.quant_config.use_int8_w8a16, 
use_int4_w4a16=self.quant_config.use_int4_w4a16, per_channel_quant=self.per_act_token_quant, block_shape=self.block_shape, B_bias=self.w2_bias, ) # separate function is required for MoE + LoRA self.moe_sum(intermediate_cache3, output) def moe_sum(self, input: torch.Tensor, output: torch.Tensor) -> None: ops.moe_sum(input, output) def modular_triton_fused_moe( quant_config: FusedMoEQuantConfig, shared_experts: torch.nn.Module | None = None ) -> mk.FusedMoEModularKernel: return mk.FusedMoEModularKernel( MoEPrepareAndFinalizeNoEP(), TritonExperts(quant_config), shared_experts, )
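

# A minimal construction sketch for the modular path (the quant config choice
# and the call signature are illustrative; see FusedMoEModularKernel for the
# authoritative interface):
#
#   fused_moe_mk = modular_triton_fused_moe(FUSED_MOE_UNQUANTIZED_CONFIG)
#   out = fused_moe_mk(hidden_states, w1, w2, topk_weights, topk_ids)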