# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from enum import IntEnum
from typing import Optional, Union

import torch

import vllm.envs as envs
from vllm.config import ParallelConfig
from vllm.distributed import (
    get_dp_group,
    get_pcp_group,
    get_tensor_model_parallel_rank,
)
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.ocp_mx_utils import (
    OCP_MX_DTYPES,
    OCP_MX_Scheme,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
from vllm.utils.import_utils import has_triton_kernels
from vllm.utils.math_utils import cdiv

logger = init_logger(__name__)

# `PrecisionConfig` is only bound when the triton kernels package is present
# and importable; everywhere else it is referenced as a string annotation.
if has_triton_kernels():
    try:
        from triton_kernels.matmul_ogs import PrecisionConfig
    except (ImportError, AttributeError) as e:
        logger.error(
            "Failed to import Triton kernels. Please make sure your triton "
            "version is compatible. Error: %s",
            e,
        )


def _get_config_dtype_str(
    dtype: torch.dtype,
    use_fp8_w8a8: bool = False,
    use_int8_w8a16: bool = False,
    use_int4_w4a16: bool = False,
    ocp_mx_scheme: str | None = None,
) -> str | None:
    """
    Return a string used to construct the filename that contains the
    tuning info for a particular quantization scheme.  See
    try_get_optimal_moe_config in fused_moe.py.

    The flags are checked in order (fp8_w8a8, int8_w8a16, int4_w4a16,
    OCP MX, float32); the first match wins. None is returned when no
    dtype-specific suffix applies.
    """
    if use_fp8_w8a8:
        return "fp8_w8a8"
    elif use_int8_w8a16:
        return "int8_w8a16"
    elif use_int4_w4a16:
        return "int4_w4a16"
    elif ocp_mx_scheme is not None:
        # The output of this function is passed to `try_get_optimal_moe_config`,
        # and as we only simulate OCP MX execution in fused_moe for now,
        # we will NOT look for `*,dtype=w_mxfp4_a_mxfp4.json` for now.
        return None
    elif dtype == torch.float:
        # avoiding cases where kernel fails when float32 MoE
        # use fp16/bfloat16 configs
        return "float32"
    return None


def _quant_flags_to_group_shape(
    quant_dtype: torch.dtype | str | None,
    per_act_token_quant: bool,
    per_out_ch_quant: bool,
    block_shape: list[int] | None,
) -> tuple[GroupShape | None, GroupShape | None]:
    """
    Convert MoE quantization flags into more generic GroupShapes.

    Returns an `(activation_shape, weight_shape)` pair. `block_shape` is
    mutually exclusive with `per_act_token_quant` and `per_out_ch_quant`
    (enforced with asserts).
    """
    a_shape: GroupShape | None
    w_shape: GroupShape | None
    if block_shape is not None:
        assert not per_act_token_quant
        assert not per_out_ch_quant
        # TODO(bnell): this is not quite right for activations since first
        # dim should be 1.
        a_shape = GroupShape(row=block_shape[0], col=block_shape[1])
        w_shape = GroupShape(row=block_shape[0], col=block_shape[1])
    else:
        w_shape = None
        # Quantized activations default to per-tensor unless overridden below.
        a_shape = None if quant_dtype is None else GroupShape.PER_TENSOR

        if per_act_token_quant:
            a_shape = GroupShape.PER_TOKEN

        if per_out_ch_quant:
            w_shape = GroupShape.PER_TOKEN

    return a_shape, w_shape


# The type of method in top-K routing
# Please keep this in sync with the counterpart defined in https://github.com/flashinfer-ai/flashinfer/blob/main/include/flashinfer/trtllm/fused_moe/runner.h
class RoutingMethodType(IntEnum):
    # NOTE: members are plain ints. They were previously spelled as 1-tuples
    # and a float, which IntEnum coerced to these same integer values.

    # Default: Softmax -> TopK
    Default = 0
    # Renormalize: TopK -> Softmax
    Renormalize = 1
    # DeepSeekV3: Sigmoid -> RoutingBiasAdd -> Top2 in group -> Top4 groups
    # -> Top8 experts from the Top4 groups
    DeepSeekV3 = 2
    # Llama4: Top1 -> Sigmoid
    Llama4 = 3
    # RenormalizeNaive: Softmax -> TopK -> Renormalize
    RenormalizeNaive = 4
    # TopK: TopK (no softmax)
    TopK = 5
    # Unspecified
    Unspecified = 6


@dataclass
class FusedMoEQuantDesc:
    """
    A quantization descriptor for fused MoE ops. This class can describe
    either activations or weights.
    """

    # The quantized type of this parameter.  None means unquantized or
    # already quantized.
    # TODO (bnell): use scalar_type instead of Union.
    dtype: torch.dtype | str | None = None

    # A field that describes the quantization group shape, from quant_utils.py.
    #  * (-1, -1)   for per-tensor quantization
    #  * (1, -1)    for per-row quantization
    #  * (-1, 1)    for per-column quantization
    #  * (128, 128) for 128x128 deepseek style block quantization
    #  * (1, 128)   for deepseek style activation quantization
    #               (i.e. per-token-per-group)
    shape: GroupShape | None = None

    # Quantization scales.
    # TODO(bnell): maybe put PrecisionConfigs in subclass of QuantDesc?
    scale: Union[torch.Tensor, "PrecisionConfig", None] = None

    # Quantization alphas or gscales, used for nvfp4 types.
    # W4A8 FP8: used for per-channel scales
    # TODO(bnell): put some of these in subclasses
    alpha_or_gscale: torch.Tensor | None = None

    # Zero points for int4/int8 types
    zp: torch.Tensor | None = None

    # Biases for GPT triton MoE
    bias: torch.Tensor | None = None


# TODO(bnell): have subclasses for specific moe methods?
# e.g. for specific arguments bias, precision, etc.
@dataclass
class FusedMoEQuantConfig:
    """
    The FusedMoEQuantConfig contains all the quantization parameters for
    a single FusedMoEMethodBase operation.  It consists of four
    FusedMoEQuantDescs, one for each activation and set of weights.

    Each FusedMoEMethodBase must implement a get_fused_moe_quant_config
    method to construct a FusedMoEQuantConfig for use with that class.

    FusedMoEQuant configs are only used for modular kernels, fused_experts
    (from fused_moe.py), cutlass_moe_fp[48], rocm_aiter_fused_experts and
    triton_kernel_moe_forward.  Other MoE methods can ignore the
    FusedMoEQuantConfig (for now) and hardcode it to None.

    There are currently some restrictions on what can be expressed:
    - Most MoE ops only support similar quantization strategies for
      each parameter, e.g. both weights must have the same GroupShape
      and both activations must share the same GroupShape.  One exception to
      this is the cutlass moe which allows per channel quantization on the
      outputs.  Note: these restrictions are not always rigorously checked.
    - Not all fused MoE functions support all the parameters, e.g. zero points,
      global scales, alphas and biases are not universally supported.
    - Fully general GroupShapes are not allowed.  Activations only support
      per token, per tensor or K-blocked.
    - Weights are not required to have a GroupShape since they have already
      been quantized.

    Other notes:
    - PrecisionConfigs are specific to GPT OSS Triton.
    - As a follow up it would probably make sense to subclass FusedMoEQuantDesc
      or FusedMoEQuantConfig for particular FusedMoEMethodBase subclasses
      so that only the required quantization parameters are used/stored.
    """

    # TODO(bnell) make sure a1_scales/a2_scales don't interfere with chunking
    _a1: FusedMoEQuantDesc
    _a2: FusedMoEQuantDesc
    _w1: FusedMoEQuantDesc
    _w2: FusedMoEQuantDesc

    def __post_init__(self):
        # Per-token activation quantization and block quantization are
        # mutually exclusive.
        assert not self.per_act_token_quant or self.block_shape is None, (
            "illegal quantization"
        )

    #
    # Convenience accessors for various properties.
    #

    @property
    def quant_dtype(self) -> torch.dtype | str | None:
        """The activation quantization dtype (taken from `_a1`)."""
        return self._a1.dtype

    @property
    def is_quantized(self) -> bool:
        """True when an activation quantization dtype is set."""
        return self.quant_dtype is not None

    @property
    def is_per_act_token(self) -> bool:
        """True when activations use per-token quantization."""
        return self._a1.shape == GroupShape.PER_TOKEN

    @property
    def per_act_token_quant(self) -> bool:
        """Same check as `is_per_act_token`, under the builder-flag name."""
        return self._a1.shape == GroupShape.PER_TOKEN

    @property
    def per_out_ch_quant(self) -> bool:
        """True when the w1 group shape equals GroupShape.PER_TOKEN."""
        return self._w1.shape == GroupShape.PER_TOKEN

    @property
    def is_per_tensor(self) -> bool:
        """True when activations use per-tensor quantization."""
        return self._a1.shape == GroupShape.PER_TENSOR

    @property
    def block_shape(self) -> list[int] | None:
        """
        The `[row, col]` activation block shape, or None when the activation
        group shape is unset, per-tensor or per-token.
        """
        if (
            self._a1.shape is not None
            and self._a1.shape != GroupShape.PER_TENSOR
            and self._a1.shape != GroupShape.PER_TOKEN
        ):
            return [self._a1.shape.row, self._a1.shape.col]
        else:
            return None

    @property
    def is_block_quantized(self) -> bool:
        """True when `block_shape` is not None."""
        return self.block_shape is not None

    @property
    def a1_scale(self) -> torch.Tensor | None:
        """Scale for a1; asserts that it is a tensor or None."""
        assert self._a1.scale is None or isinstance(self._a1.scale, torch.Tensor)
        return self._a1.scale

    @property
    def a1_gscale(self) -> torch.Tensor | None:
        """The `alpha_or_gscale` field of a1."""
        return self._a1.alpha_or_gscale

    @property
    def a2_scale(self) -> torch.Tensor | None:
        """Scale for a2; asserts that it is a tensor or None."""
        assert self._a2.scale is None or isinstance(self._a2.scale, torch.Tensor)
        return self._a2.scale

    @property
    def a2_gscale(self) -> torch.Tensor | None:
        """The `alpha_or_gscale` field of a2."""
        return self._a2.alpha_or_gscale

    @property
    def w1_scale(self) -> torch.Tensor | None:
        """Scale for w1; asserts that it is a tensor or None."""
        assert self._w1.scale is None or isinstance(self._w1.scale, torch.Tensor)
        return self._w1.scale

    @property
    def w1_zp(self) -> torch.Tensor | None:
        """Zero points for w1."""
        return self._w1.zp

    @property
    def w1_bias(self) -> torch.Tensor | None:
        """Bias for w1."""
        return self._w1.bias

    @property
    def w1_precision(self) -> Optional["PrecisionConfig"]:
        """
        The w1 scale viewed as a PrecisionConfig; asserts that it is one or
        None.
        """
        # NOTE(review): `PrecisionConfig` is only bound if the guarded import
        # at the top of the file succeeded - confirm callers only reach the
        # isinstance check in that case.
        assert self._w1.scale is None or isinstance(self._w1.scale, PrecisionConfig)
        return self._w1.scale

    @property
    def g1_alphas(self) -> torch.Tensor | None:
        """The `alpha_or_gscale` field of w1."""
        return self._w1.alpha_or_gscale

    @property
    def w2_scale(self) -> torch.Tensor | None:
        """Scale for w2; asserts that it is a tensor or None."""
        assert self._w2.scale is None or isinstance(self._w2.scale, torch.Tensor)
        return self._w2.scale

    @property
    def w2_zp(self) -> torch.Tensor | None:
        """Zero points for w2."""
        return self._w2.zp

    @property
    def w2_bias(self) -> torch.Tensor | None:
        """Bias for w2."""
        return self._w2.bias

    @property
    def w2_precision(self) -> Optional["PrecisionConfig"]:
        """
        The w2 scale viewed as a PrecisionConfig; asserts that it is one or
        None.
        """
        # NOTE(review): same caveat as `w1_precision` regarding the
        # conditionally imported `PrecisionConfig`.
        assert self._w2.scale is None or isinstance(self._w2.scale, PrecisionConfig)
        return self._w2.scale

    @property
    def g2_alphas(self) -> torch.Tensor | None:
        """The `alpha_or_gscale` field of w2."""
        return self._w2.alpha_or_gscale

    @property
    def use_fp8_w8a8(self) -> bool:
        """True when the activation quant dtype is torch.float8_e4m3fn."""
        return self.quant_dtype == torch.float8_e4m3fn

    @property
    def use_int8_w8a8(self) -> bool:
        """True when the activation quant dtype is torch.int8."""
        return self.quant_dtype == torch.int8

    @property
    def use_int8_w8a16(self) -> bool:
        """True when a1 has no quant dtype and w1 is torch.int8."""
        return self._a1.dtype is None and self._w1.dtype == torch.int8

    @property
    def use_int4_w4a16(self) -> bool:
        """True when a1 has no quant dtype and w1 is "int4"."""
        return self._a1.dtype is None and self._w1.dtype == "int4"

    @property
    def ocp_mx_scheme(self) -> str | None:
        """
        The OCP MX scheme value derived from the a1/w1 dtypes, or None.

        The result is computed on first access and cached on the instance
        as `_ocp_mx_scheme`.
        """
        if not hasattr(self, "_ocp_mx_scheme"):
            # A torch.dtype on either side rules out an OCP MX scheme; only
            # string (or None) dtypes are handed to OCP_MX_Scheme.
            if (self._a1.dtype is not None and not isinstance(self._a1.dtype, str)) or (
                self._w1.dtype is not None and not isinstance(self._w1.dtype, str)
            ):
                self._ocp_mx_scheme = None
            else:
                ocp_mx_scheme = OCP_MX_Scheme.from_quant_dtype(
                    self._a1.dtype, self._w1.dtype
                )

                if ocp_mx_scheme is not None:
                    ocp_mx_scheme = ocp_mx_scheme.value

                self._ocp_mx_scheme = ocp_mx_scheme

        return self._ocp_mx_scheme

    @property
    def use_mxfp4_w4a16(self) -> bool:
        """True when a1 has no quant dtype and w1 is "mxfp4"."""
        return self._a1.dtype is None and self._w1.dtype == "mxfp4"

    @property
    def use_mxfp4_w4a4(self) -> bool:
        """True when both a1 and w1 are "mxfp4"."""
        return self._a1.dtype == "mxfp4" and self._w1.dtype == "mxfp4"

    @property
    def use_nvfp4_w4a4(self) -> bool:
        """True when the activation quant dtype is "nvfp4"."""
        return self.quant_dtype == "nvfp4"

    def config_name(self, dtype: torch.dtype) -> str | None:
        """
        Return a string used to construct the filename that contains the
        tuning info for a particular quantization scheme.  See
        try_get_optimal_moe_config in fused_moe.py.
        """
        return _get_config_dtype_str(
            use_fp8_w8a8=self.use_fp8_w8a8,
            use_int8_w8a16=self.use_int8_w8a16,
            use_int4_w4a16=self.use_int4_w4a16,
            ocp_mx_scheme=self.ocp_mx_scheme,
            dtype=dtype,
        )

    def scale_shape(
        self,
        max_tokens: int,
        hidden_dim: int,
    ) -> tuple[int, int] | None:
        """
        Construct the proper activation scale shape for this
        config.

        Returns None when activations are unquantized, otherwise:
        - block quantized: (max_tokens, cdiv(hidden_dim, block_k))
        - per token:       (max_tokens, 1)
        - anything else:   (1, 1)
        """
        if self.is_quantized:
            if self.is_block_quantized:
                assert self.block_shape is not None
                _, block_k = self.block_shape
                k_tiles = cdiv(hidden_dim, block_k)
                return (max_tokens, k_tiles)
            elif self.is_per_act_token:
                return (max_tokens, 1)
            else:
                return (1, 1)
        else:
            return None

    def batched_scale_shape(
        self,
        num_experts: int,
        max_tokens: int,
        hidden_dim: int,
    ) -> tuple[int, int, int] | None:
        """
        Construct the proper activation batched scale shape for this
        config, e.g. (num experts, *scale_shape).
        """
        if self.is_quantized:
            scale_shape = self.scale_shape(max_tokens, hidden_dim)
            assert scale_shape is not None
            return (num_experts, *scale_shape)
        else:
            return None

    @staticmethod
    def make(
        quant_dtype: torch.dtype | str | None = None,
        per_act_token_quant: bool = False,
        per_out_ch_quant: bool = False,
        block_shape: list[int] | None = None,
        w1_scale: Union[torch.Tensor, "PrecisionConfig", None] = None,
        w2_scale: Union[torch.Tensor, "PrecisionConfig", None] = None,
        a1_scale: torch.Tensor | None = None,
        a2_scale: torch.Tensor | None = None,
        g1_alphas: torch.Tensor | None = None,
        g2_alphas: torch.Tensor | None = None,
        a1_gscale: torch.Tensor | None = None,
        a2_gscale: torch.Tensor | None = None,
        w1_bias: torch.Tensor | None = None,
        w2_bias: torch.Tensor | None = None,
        w1_zp: torch.Tensor | None = None,
        w2_zp: torch.Tensor | None = None,
        weight_dtype: torch.dtype | str | None = None,
    ) -> "FusedMoEQuantConfig":
        """
        General builder function for a FusedMoEQuantConfig.
        - quant_dtype: Optional quantization type. None if activations are
          unquantized or quantized prior to calling.  Note: "nvfp4", "mxfp4",
          "mxfp6_e3m2", "mxfp6_e2m3" are the only valid string values
          for quant_dtype.
        - per_act_token_quant: Activations have per token quantization.
        - per_out_ch_quant: Outputs have per channel quantization. (only
          for cutlass).
        - block_shape: Optional block size for block-wise quantization.
          Incompatible with per_act_token and per_out_ch quant.
        - w1_scale: Optional scale to be used for w1.
        - w2_scale: Optional scale to be used for w2.
        - a1_scale: Optional scale to be used for a1.
        - a2_scale: Optional scale to be used for a2.
        - g1_alphas: Optional global quantization scales for w1 (for nvfp4).
            per-channel scales for w1 (for W4A8 FP8).
        - g2_alphas: Optional global quantization scales for w2 (for nvfp4).
            per-channel scales for w2 (for W4A8 FP8).
        - a1_gscale: Optional global quantization scales for a1 (for nvfp4).
        - a2_gscale: Optional global quantization scales for a2 (for nvfp4).
        - w1_bias: Optional biases for w1 (GPT OSS Triton).
        - w2_bias: Optional biases for w2 (GPT OSS Triton).
        - w1_zp: Optional w1 zero points for int4/int8 quantization.
        - w2_zp: Optional w2 zero points for int4/int8 quantization.
        - weight_dtype: Optional weight type. Defaults to quant_dtype when
          None.  Valid string values are those of quant_dtype plus "int4".
        """
        assert not isinstance(quant_dtype, str) or quant_dtype in {
            "nvfp4",
            "mxfp4",
            "mxfp6_e3m2",
            "mxfp6_e2m3",
        }
        assert not isinstance(weight_dtype, str) or weight_dtype in {
            "nvfp4",
            "mxfp4",
            "mxfp6_e3m2",
            "mxfp6_e2m3",
            "int4",
        }

        if weight_dtype is None:
            weight_dtype = quant_dtype

        a_shape, w_shape = _quant_flags_to_group_shape(
            quant_dtype, per_act_token_quant, per_out_ch_quant, block_shape
        )
        quant_config = FusedMoEQuantConfig(
            _a1=FusedMoEQuantDesc(quant_dtype, a_shape, a1_scale, a1_gscale),
            _a2=FusedMoEQuantDesc(quant_dtype, a_shape, a2_scale, a2_gscale),
            _w1=FusedMoEQuantDesc(
                weight_dtype, w_shape, w1_scale, g1_alphas, w1_zp, w1_bias
            ),
            _w2=FusedMoEQuantDesc(
                weight_dtype, w_shape, w2_scale, g2_alphas, w2_zp, w2_bias
            ),
        )
        # Sanity check: the derived properties must round-trip the flags
        # that were passed in.
        assert quant_config.per_act_token_quant == per_act_token_quant
        assert quant_config.per_out_ch_quant == per_out_ch_quant
        assert quant_config.block_shape == block_shape
        return quant_config


def fp8_w8a8_moe_quant_config(
    w1_scale: torch.Tensor,
    w2_scale: torch.Tensor,
    a1_scale: torch.Tensor | None = None,
    a2_scale: torch.Tensor | None = None,
    per_act_token_quant: bool = False,
    per_out_ch_quant: bool = False,
    block_shape: list[int] | None = None,
    a1_gscale: torch.Tensor | None = None,
    a2_gscale: torch.Tensor | None = None,
    g1_alphas: torch.Tensor | None = None,
    g2_alphas: torch.Tensor | None = None,
) -> FusedMoEQuantConfig:
    """
    Construct a quant config for fp8 activations and fp8 weights.
    """
    return FusedMoEQuantConfig.make(
        torch.float8_e4m3fn,
        w1_scale=w1_scale,
        g1_alphas=g1_alphas,
        w2_scale=w2_scale,
        g2_alphas=g2_alphas,
        a1_scale=a1_scale,
        a1_gscale=a1_gscale,
        a2_scale=a2_scale,
        a2_gscale=a2_gscale,
        per_act_token_quant=per_act_token_quant,
        per_out_ch_quant=per_out_ch_quant,
        block_shape=block_shape,
    )


def int8_w8a8_moe_quant_config(
    w1_scale: torch.Tensor,
    w2_scale: torch.Tensor,
    a1_scale: torch.Tensor | None,
    a2_scale: torch.Tensor | None,
    per_act_token_quant: bool = False,
) -> FusedMoEQuantConfig:
    """
    Construct a quant config for int8 activations and int8 weights.
    """
    return FusedMoEQuantConfig.make(
        torch.int8,
        w1_scale=w1_scale,
        w2_scale=w2_scale,
        a1_scale=a1_scale,
        a2_scale=a2_scale,
        per_act_token_quant=per_act_token_quant,
        per_out_ch_quant=False,
        block_shape=None,
    )


def gptq_marlin_moe_quant_config(
    w1_scale: torch.Tensor,
    w2_scale: torch.Tensor,
    weight_bits: int,
    group_size: int,
    w1_zp: torch.Tensor | None = None,
    w2_zp: torch.Tensor | None = None,
    w1_bias: torch.Tensor | None = None,
    w2_bias: torch.Tensor | None = None,
) -> FusedMoEQuantConfig:
    """
    Construct a quant config for gptq marlin quantization.

    `weight_bits` must be 4 or 8, otherwise ValueError is raised.
    A `group_size` of -1 leaves the group shape unset.
    """
    w_shape = None if group_size == -1 else GroupShape(row=1, col=group_size)

    # Activations are NOT quantized for GPTQ (fp16/bf16)
    a_shape = w_shape  # Same as weight shape for alignment

    # Determine weight dtype
    if weight_bits == 4:
        weight_dtype = "int4"
    elif weight_bits == 8:
        weight_dtype = torch.int8
    else:
        raise ValueError(f"Unsupported weight_bits: {weight_bits}")

    return FusedMoEQuantConfig(
        _a1=FusedMoEQuantDesc(dtype=None, shape=a_shape),
        _a2=FusedMoEQuantDesc(dtype=None, shape=a_shape),
        _w1=FusedMoEQuantDesc(weight_dtype, w_shape, w1_scale, None, w1_zp, w1_bias),
        _w2=FusedMoEQuantDesc(weight_dtype, w_shape, w2_scale, None, w2_zp, w2_bias),
    )


def mxfp4_w4a16_moe_quant_config(
    w1_scale: Union[torch.Tensor, "PrecisionConfig"],
    w2_scale: Union[torch.Tensor, "PrecisionConfig"],
    w1_bias: torch.Tensor | None = None,
    w2_bias: torch.Tensor | None = None,
) -> FusedMoEQuantConfig:
    """
    Construct a quant config for unquantized activations and mxfp4 weights.
    """
    return FusedMoEQuantConfig(
        _a1=FusedMoEQuantDesc(),
        _a2=FusedMoEQuantDesc(),
        _w1=FusedMoEQuantDesc("mxfp4", None, w1_scale, None, None, w1_bias),
        _w2=FusedMoEQuantDesc("mxfp4", None, w2_scale, None, None, w2_bias),
    )


def mxfp4_mxfp8_moe_quant_config(
    w1_scale: Union[torch.Tensor, "PrecisionConfig"],
    w2_scale: Union[torch.Tensor, "PrecisionConfig"],
    a1_scale: torch.Tensor | None = None,
    a2_scale: torch.Tensor | None = None,
    w1_bias: torch.Tensor | None = None,
    w2_bias: torch.Tensor | None = None,
    block_shape: list[int] | None = None,
) -> FusedMoEQuantConfig:
    """
    Construct a quant config for mxfp8 activations and mxfp4 weights.
    """
    # NOTE(review): `a1_scale`, `a2_scale` and `block_shape` are accepted but
    # not forwarded to the descriptors - confirm whether that is intentional.
    return FusedMoEQuantConfig(
        _a1=FusedMoEQuantDesc("mxfp8"),
        _a2=FusedMoEQuantDesc("mxfp8"),
        _w1=FusedMoEQuantDesc("mxfp4", None, w1_scale, None, None, w1_bias),
        _w2=FusedMoEQuantDesc("mxfp4", None, w2_scale, None, None, w2_bias),
    )


def ocp_mx_moe_quant_config(
    quant_dtype: str,
    w1_scale: Union[torch.Tensor, "PrecisionConfig"],
    w2_scale: Union[torch.Tensor, "PrecisionConfig"],
    weight_dtype: str | None = None,
    a1_scale: torch.Tensor | None = None,
    a2_scale: torch.Tensor | None = None,
    w1_bias: torch.Tensor | None = None,
    w2_bias: torch.Tensor | None = None,
    block_shape: list[int] | None = None,
) -> FusedMoEQuantConfig:
    """
    Construct a quant config for OCP MX activations and weights.

    `quant_dtype` must be one of OCP_MX_DTYPES.  `weight_dtype` falls back
    to `quant_dtype` when None (see FusedMoEQuantConfig.make).
    """
    assert quant_dtype in OCP_MX_DTYPES
    return FusedMoEQuantConfig.make(
        quant_dtype=quant_dtype,
        weight_dtype=weight_dtype,
        w1_scale=w1_scale,
        w2_scale=w2_scale,
        a1_scale=a1_scale,
        a2_scale=a2_scale,
        w1_bias=w1_bias,
        w2_bias=w2_bias,
        per_act_token_quant=False,
        per_out_ch_quant=False,
        block_shape=block_shape,
    )


def nvfp4_moe_quant_config(
    g1_alphas: torch.Tensor,
    g2_alphas: torch.Tensor,
    a1_gscale: torch.Tensor,
    a2_gscale: torch.Tensor,
    w1_scale: torch.Tensor,
    w2_scale: torch.Tensor,
) -> FusedMoEQuantConfig:
    """
    Construct a quant config for nvfp4 activations and nvfp4 weights.
    """
    return FusedMoEQuantConfig.make(
        "nvfp4",
        w1_scale=w1_scale,
        w2_scale=w2_scale,
        a1_gscale=a1_gscale,
        a2_gscale=a2_gscale,
        g1_alphas=g1_alphas,
        g2_alphas=g2_alphas,
        per_act_token_quant=False,
        per_out_ch_quant=False,
        block_shape=None,
    )


def int4_w4a16_moe_quant_config(
    w1_scale: torch.Tensor,
    w2_scale: torch.Tensor,
    w1_zp: torch.Tensor | None,
    w2_zp: torch.Tensor | None,
    block_shape: list[int] | None = None,
) -> FusedMoEQuantConfig:
    """
    Construct a quant config for 16-bit float activations and int4 weights.
    Note: Activations are pre-quantized.
    """
    group_shape = GroupShape(*block_shape) if block_shape is not None else None
    return FusedMoEQuantConfig(
        _a1=FusedMoEQuantDesc(shape=group_shape),
        _a2=FusedMoEQuantDesc(shape=group_shape),
        _w1=FusedMoEQuantDesc("int4", group_shape, w1_scale, None, w1_zp),
        _w2=FusedMoEQuantDesc("int4", group_shape, w2_scale, None, w2_zp),
    )


def int8_w8a16_moe_quant_config(
    w1_scale: torch.Tensor,
    w2_scale: torch.Tensor,
    w1_zp: torch.Tensor | None,
    w2_zp: torch.Tensor | None,
    block_shape: list[int] | None = None,
) -> FusedMoEQuantConfig:
    """
    Construct a quant config for 16-bit float activations and int8 weights.
    Note: Activations are pre-quantized.
    """
    group_shape = GroupShape(*block_shape) if block_shape is not None else None
    return FusedMoEQuantConfig(
        _a1=FusedMoEQuantDesc(shape=group_shape),
        _a2=FusedMoEQuantDesc(shape=group_shape),
        _w1=FusedMoEQuantDesc(torch.int8, group_shape, w1_scale, None, w1_zp),
        _w2=FusedMoEQuantDesc(torch.int8, group_shape, w2_scale, None, w2_zp),
    )


def int4_w4afp8_moe_quant_config(
    w1_scale: torch.Tensor,
    w2_scale: torch.Tensor,
    g1_alphas: torch.Tensor,
    g2_alphas: torch.Tensor,
    per_act_token_quant: bool = False,
    per_out_ch_quant: bool = False,
    block_shape: list[int] | None = None,
) -> FusedMoEQuantConfig:
    """
    Construct a quant config for fp8 activations and int4 weights.
    """
    return FusedMoEQuantConfig.make(
        torch.float8_e4m3fn,  # quant dtype for activations
        w1_scale=w1_scale,
        w2_scale=w2_scale,
        g1_alphas=g1_alphas,
        g2_alphas=g2_alphas,
        per_act_token_quant=per_act_token_quant,
        per_out_ch_quant=per_out_ch_quant,
        block_shape=block_shape,
        weight_dtype="int4",  # weight dtype for weights
    )


def awq_marlin_moe_quant_config(
    w1_scale: torch.Tensor,
    w2_scale: torch.Tensor,
    w1_zp: torch.Tensor | None,
    w2_zp: torch.Tensor | None,
    weight_bits: int,
    group_size: int,
    w1_bias: torch.Tensor | None = None,
    w2_bias: torch.Tensor | None = None,
) -> FusedMoEQuantConfig:
    """
    Construct a quant config for awq marlin quantization.

    `weight_bits` must be 4 or 8, otherwise ValueError is raised.
    A `group_size` of -1 leaves the group shape unset.
    """
    w_shape = None if group_size == -1 else GroupShape(row=1, col=group_size)

    # Activations are NOT quantized for AWQ (fp16/bf16)
    a_shape = w_shape  # Same as weight shape for alignment

    # Determine weight dtype
    if weight_bits == 4:
        weight_dtype = "int4"
    elif weight_bits == 8:
        weight_dtype = torch.int8
    else:
        raise ValueError(f"Unsupported weight_bits: {weight_bits}")

    return FusedMoEQuantConfig(
        _a1=FusedMoEQuantDesc(dtype=None, shape=a_shape),
        _a2=FusedMoEQuantDesc(dtype=None, shape=a_shape),
        _w1=FusedMoEQuantDesc(weight_dtype, w_shape, w1_scale, None, w1_zp, w1_bias),
        _w2=FusedMoEQuantDesc(weight_dtype, w_shape, w2_scale, None, w2_zp, w2_bias),
    )


def biased_moe_quant_config(
    w1_bias: torch.Tensor | None,
    w2_bias: torch.Tensor | None,
) -> FusedMoEQuantConfig:
    """
    Construct a quant config for unquantized activations with biases.
    """
    return FusedMoEQuantConfig(
        _a1=FusedMoEQuantDesc(),
        _a2=FusedMoEQuantDesc(),
        _w1=FusedMoEQuantDesc(bias=w1_bias),
        _w2=FusedMoEQuantDesc(bias=w2_bias),
    )


# A FusedMoEQuantConfig constant for an unquantized MoE op.
FUSED_MOE_UNQUANTIZED_CONFIG: FusedMoEQuantConfig = FusedMoEQuantConfig.make()


@dataclass
class FusedMoEParallelConfig:
    """
    Sizes and ranks of the TP / PCP / DP / EP dimensions used by a fused MoE
    layer, plus the all2all backend name.  See `make` for how they are
    derived.
    """

    tp_size: int
    pcp_size: int
    dp_size: int
    ep_size: int
    tp_rank: int
    pcp_rank: int
    dp_rank: int
    ep_rank: int

    use_ep: bool  # whether to use EP or not
    all2all_backend: str  # all2all backend for MoE communication

    @property
    def use_all2all_kernels(self):
        """True when dp_size > 1 and expert parallelism is in use."""
        return self.dp_size > 1 and self.use_ep

    @property
    def use_pplx_kernels(self):
        """True when all2all kernels are used with the "pplx" backend."""
        return self.use_all2all_kernels and self.all2all_backend == "pplx"

    @property
    def use_deepep_ht_kernels(self):
        """
        True when all2all kernels are used with the
        "deepep_high_throughput" backend.
        """
        return (
            self.use_all2all_kernels
            and self.all2all_backend == "deepep_high_throughput"
        )

    @property
    def use_deepep_ll_kernels(self):
        """
        True when all2all kernels are used with the "deepep_low_latency"
        backend.
        """
        return self.use_all2all_kernels and self.all2all_backend == "deepep_low_latency"

    @staticmethod
    def flatten_tp_across_dp_and_pcp(
        tp_size: int, dp_size: int, dp_rank: int, pcp_size: int, pcp_rank: int
    ) -> tuple[int, int]:
        """
        Fold the DP and PCP dimensions into TP.

        Returns `(flatten_tp_size, flatten_tp_rank)` where the size is
        `dp_size * pcp_size * tp_size` and the rank orders devices by
        DP rank, then PCP rank, then TP rank.
        """
        tp_rank = 0 if tp_size == 1 else get_tensor_model_parallel_rank()
        # There are actually dp_size * pcp_size * tp_size devices.
        # Update tp_size and tp_rank so we shard across all devices.
        flatten_tp_size = dp_size * pcp_size * tp_size
        flatten_tp_rank = dp_rank * pcp_size * tp_size + pcp_rank * tp_size + tp_rank
        return flatten_tp_size, flatten_tp_rank

    @staticmethod
    def make(
        tp_size_: int,
        pcp_size_: int,
        dp_size_: int,
        vllm_parallel_config: ParallelConfig,
    ) -> "FusedMoEParallelConfig":
        """
        Determine MoE parallel configuration. Based on the input `tp_size_`,
        `dp_size_` and vllm's parallel config, determine what
        levels of parallelism to use in the fused moe layer.

        Args:
            tp_size_ (int): `tp_size` passed into the FusedMoE constructor.
            pcp_size_ (int): `pcp_size` passed into the FusedMoE constructor.
            dp_size_ (int): `dp_size` passed into the FusedMoE constructor.
            vllm_parallel_config (ParallelConfig): vLLM's parallel config
                object which contains the `enable_expert_parallel` flag.

        Examples:
            When there is no parallelism requested,
            i.e. `tp_size_` = `pcp_size_` = `dp_size_` = 1, we simply return
            the sizes unaltered and the ranks set to 0.

            Expert Parallelism is considered only when either `dp_size_`,
            `pcp_size_` or `tp_size_` is non trivial.

            Note that PCP serves the same function as DP here.

            When TP = 2, DP(PCP) = 1 and EP = False, the configuration on
            different devices:
                - device 0 : TP = {2, 0} DP = {1, 0} EP = {1, 0} //
                    legend : {size, rank}
                - device 1 : TP = {2, 1} DP = {1, 0} EP = {1, 0}
                - Comment : Tensors are sharded across 2 devices.

            When TP = 1, DP(PCP) = 2 and EP = False, the configuration on
            different devices:
                - device 0 : TP = {2, 0} DP = {2, 0} EP = {1, 0}
                - device 1 : TP = {2, 1} DP = {2, 1} EP = {1, 0}
                - Comment: There are 2 engine instances and the tensors are
                    sharded across 2 devices.

            When TP = 2, DP(PCP) = 2 and EP = False, the configuration on
            different devices:
                - device 0: TP = {4, 0} DP = {2, 0} EP = {1, 0}
                - device 1: TP = {4, 1} DP = {2, 0} EP = {1, 0}
                - device 2: TP = {4, 2} DP = {2, 1} EP = {1, 0}
                - device 3: TP = {4, 3} DP = {2, 1} EP = {1, 0}
                - Comment: There are 2 engine instances and the tensors are
                    sharded across 4 devices.

            When, TP = 2, DP(PCP) = 1 and EP = True, the configuration on
            different devices:
                - device 0: TP = {1, 0} DP = {1, 0} EP = {2, 0}
                - device 1: TP = {1, 0} DP = {1, 0} EP = {2, 1}
                - Comment: The experts are split between the 2 devices.

            When, TP = 1, DP(PCP) = 2 and EP = True, the configuration on
            different devices:
                - device 0: TP = {1, 0} DP = {2, 0} EP = {2, 0}
                - device 1: TP = {1, 0} DP = {2, 1} EP = {2, 1}
                - Comment: There are 2 engine instances and the experts are
                    split between the 2 devices.

            When TP = 2, DP(PCP) = 2 and EP = True, the configuration on
            different devices:
                - device 0: TP = {1, 0} DP = {2, 0} EP = {4, 0}
                - device 1: TP = {1, 0} DP = {2, 0} EP = {4, 1}
                - device 2: TP = {1, 0} DP = {2, 1} EP = {4, 2}
                - device 3: TP = {1, 0} DP = {2, 1} EP = {4, 3}
                - Comment: There are 2 engine instances and the experts are
                    split between the 4 devices.
        """

        use_ep = (
            dp_size_ * pcp_size_ * tp_size_ > 1
            and vllm_parallel_config.enable_expert_parallel
        )

        dp_size = dp_size_
        dp_rank = get_dp_group().rank_in_group if dp_size > 1 else 0
        pcp_size = pcp_size_
        pcp_rank = get_pcp_group().rank_in_group if pcp_size > 1 else 0
        tp_size, tp_rank = FusedMoEParallelConfig.flatten_tp_across_dp_and_pcp(
            tp_size_, dp_size_, dp_rank, pcp_size_, pcp_rank
        )

        if not use_ep:
            return FusedMoEParallelConfig(
                tp_size=tp_size,
                tp_rank=tp_rank,
                pcp_size=pcp_size,
                pcp_rank=pcp_rank,
                dp_size=dp_size,
                dp_rank=dp_rank,
                ep_size=1,
                ep_rank=0,
                use_ep=False,
                all2all_backend=vllm_parallel_config.all2all_backend,
            )
        # DP + EP / TP + EP / DP + TP + EP
        assert use_ep
        # In EP, each device owns a set of experts fully. There is no tensor
        # parallel update tp_size, tp_rank, ep_size and ep_rank to reflect that.
        ep_size = tp_size
        ep_rank = tp_rank
        return FusedMoEParallelConfig(
            tp_size=1,
            tp_rank=0,
            pcp_size=pcp_size,
            pcp_rank=pcp_rank,
            dp_size=dp_size,
            dp_rank=dp_rank,
            ep_size=ep_size,
            ep_rank=ep_rank,
            use_ep=True,
            all2all_backend=vllm_parallel_config.all2all_backend,
        )


# Adapted from pplx-kernels tests/all_to_all_utils.py
@dataclass
class FusedMoEConfig:
    """
    Static description of a fused MoE layer: expert counts, hidden size,
    activation dtype and the parallel configuration it runs under.  Most
    properties simply forward to `moe_parallel_config`.
    """

    num_experts: int
    experts_per_token: int
    hidden_dim: int

    num_local_experts: int
    moe_parallel_config: FusedMoEParallelConfig

    # The activation type.
    in_dtype: torch.dtype

    max_num_tokens: int = envs.VLLM_MOE_DP_CHUNK_SIZE

    has_bias: bool = False

    is_act_and_mul: bool = True

    is_lora_enabled: bool = False

    def __post_init__(self):
        if self.dp_size > 1:
            logger.debug_once(
                "Using FusedMoEConfig::max_num_tokens=%d", self.max_num_tokens
            )

        assert self.max_num_tokens > 0

    @property
    def tp_size(self):
        """Forwarded from `moe_parallel_config.tp_size`."""
        return self.moe_parallel_config.tp_size

    @property
    def dp_size(self):
        """Forwarded from `moe_parallel_config.dp_size`."""
        return self.moe_parallel_config.dp_size

    @property
    def ep_size(self):
        """Forwarded from `moe_parallel_config.ep_size`."""
        return self.moe_parallel_config.ep_size

    @property
    def tp_rank(self):
        """Forwarded from `moe_parallel_config.tp_rank`."""
        return self.moe_parallel_config.tp_rank

    @property
    def dp_rank(self):
        """Forwarded from `moe_parallel_config.dp_rank`."""
        return self.moe_parallel_config.dp_rank

    @property
    def ep_rank(self):
        """Forwarded from `moe_parallel_config.ep_rank`."""
        return self.moe_parallel_config.ep_rank

    @property
    def use_ep(self):
        """Forwarded from `moe_parallel_config.use_ep`."""
        return self.moe_parallel_config.use_ep

    @property
    def use_pplx_kernels(self):
        """Forwarded from `moe_parallel_config.use_pplx_kernels`."""
        return self.moe_parallel_config.use_pplx_kernels

    @property
    def use_deepep_ht_kernels(self):
        """Forwarded from `moe_parallel_config.use_deepep_ht_kernels`."""
        return self.moe_parallel_config.use_deepep_ht_kernels

    @property
    def use_deepep_ll_kernels(self):
        """Forwarded from `moe_parallel_config.use_deepep_ll_kernels`."""
        return self.moe_parallel_config.use_deepep_ll_kernels

    @property
    def use_flashinfer_cutlass_kernels(self):
        """
        Whether to use FlashInfer cutlass kernels for NVFP4 MoE.
        """
        return (
            envs.VLLM_USE_FLASHINFER_MOE_FP4
            and has_flashinfer_cutlass_fused_moe()
            and envs.VLLM_FLASHINFER_MOE_BACKEND == "throughput"
        )