This commit is contained in: v1.0

model_executor/layers/fused_moe/config.py (new file, 916 lines)

@@ -0,0 +1,916 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from enum import IntEnum
from typing import Optional, Union

import torch

import vllm.envs as envs
from vllm.config import ParallelConfig
from vllm.distributed import get_dp_group, get_tensor_model_parallel_rank
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.ocp_mx_utils import (
    OCP_MX_DTYPES,
    OCP_MX_Scheme,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
from vllm.utils.import_utils import has_triton_kernels
from vllm.utils.math_utils import cdiv

logger = init_logger(__name__)

if has_triton_kernels():
    try:
        from triton_kernels.matmul_ogs import PrecisionConfig
    except ImportError:
        logger.error(
            "Failed to import Triton kernels. Please make sure your triton "
            "version is compatible."
        )


def _get_config_dtype_str(
    dtype: torch.dtype,
    use_fp8_w8a8: bool = False,
    use_int8_w8a16: bool = False,
    use_int4_w4a16: bool = False,
    ocp_mx_scheme: str | None = None,
) -> str | None:
    """
    Return a string used to construct the filename that contains the
    tuning info for a particular quantization scheme. See
    try_get_optimal_moe_config in fused_moe.py.
    """
    if use_fp8_w8a8:
        return "fp8_w8a8"
    elif use_int8_w8a16:
        return "int8_w8a16"
    elif use_int4_w4a16:
        return "int4_w4a16"
    elif ocp_mx_scheme is not None:
        # The output of this function is passed to `try_get_optimal_moe_config`,
        # and as we only simulate OCP MX execution in fused_moe for now,
        # we will NOT look for `*,dtype=w_mxfp4_a_mxfp4.json` for now.
        return None
    elif dtype == torch.float:
        # Use a dedicated float32 config to avoid cases where the kernel
        # fails when a float32 MoE picks up fp16/bfloat16 configs.
        return "float32"
    return None
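
# Illustration (editor's sketch, derived from the branches above): the mapping
# selects the tuned-config filename suffix, e.g.
#   _get_config_dtype_str(torch.bfloat16, use_fp8_w8a8=True)  -> "fp8_w8a8"
#   _get_config_dtype_str(torch.float)                        -> "float32"
#   _get_config_dtype_str(torch.bfloat16)                     -> None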


def _quant_flags_to_group_shape(
    quant_dtype: torch.dtype | str | None,
    per_act_token_quant: bool,
    per_out_ch_quant: bool,
    block_shape: list[int] | None,
) -> tuple[GroupShape | None, GroupShape | None]:
    """
    Convert MoE quantization flags into more generic GroupShapes.
    """
    a_shape: GroupShape | None
    w_shape: GroupShape | None
    if block_shape is not None:
        assert not per_act_token_quant
        assert not per_out_ch_quant
        # TODO(bnell): this is not quite right for activations since first
        # dim should be 1.
        a_shape = GroupShape(row=block_shape[0], col=block_shape[1])
        w_shape = GroupShape(row=block_shape[0], col=block_shape[1])
    else:
        w_shape = None
        a_shape = None if quant_dtype is None else GroupShape.PER_TENSOR

        if per_act_token_quant:
            a_shape = GroupShape.PER_TOKEN

        if per_out_ch_quant:
            w_shape = GroupShape.PER_TOKEN

    return a_shape, w_shape
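
# Illustration (editor's sketch): DeepSeek-style 128x128 block quantization
# maps as
#   _quant_flags_to_group_shape(torch.float8_e4m3fn, False, False, [128, 128])
#   -> (GroupShape(row=128, col=128), GroupShape(row=128, col=128))
# while plain per-tensor fp8 maps as
#   _quant_flags_to_group_shape(torch.float8_e4m3fn, False, False, None)
#   -> (GroupShape.PER_TENSOR, None)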


# The type of method in top-K routing.
# Please keep this in sync with the counterpart defined in
# https://github.com/flashinfer-ai/flashinfer/blob/main/include/flashinfer/trtllm/fused_moe/runner.h
class RoutingMethodType(IntEnum):
    # Default: Softmax -> TopK
    Default = 0
    # Renormalize: TopK -> Softmax
    Renormalize = 1
    # DeepSeekV3: Sigmoid -> RoutingBiasAdd -> Top2 in group -> Top4 groups
    # -> Top8 experts from the Top4 groups
    DeepSeekV3 = 2
    # Llama4: Top1 -> Sigmoid
    Llama4 = 3
    # RenormalizeNaive: Softmax -> TopK -> Renormalize
    RenormalizeNaive = 4
    # TopK: TopK (no softmax)
    TopK = 5
    # Unspecified
    Unspecified = 6
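
# Illustration (editor's sketch): these integer values are what cross the
# FlashInfer boundary, e.g. int(RoutingMethodType.DeepSeekV3) == 2, so they
# must stay aligned with the enum in the runner.h file linked above.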


@dataclass
class FusedMoEQuantDesc:
    """
    A quantization descriptor for fused MoE ops. This class can describe
    either activations or weights.
    """

    # The quantized type of this parameter. None means unquantized or
    # already quantized.
    # TODO (bnell): use scalar_type instead of Union.
    dtype: torch.dtype | str | None = None

    # A field that describes the quantization group shape, from quant_utils.py.
    # * (-1, -1) for per-tensor quantization
    # * (1, -1) for per-row quantization
    # * (-1, 1) for per-column quantization
    # * (128, 128) for 128x128 deepseek style block quantization
    # * (1, 128) for deepseek style activation quantization
    #   (i.e. per-token-per-group)
    shape: GroupShape | None = None

    # Quantization scales.
    # TODO(bnell): maybe put PrecisionConfigs in subclass of QuantDesc?
    scale: Union[torch.Tensor, "PrecisionConfig", None] = None

    # Quantization alphas or gscales, used for nvfp4 types.
    # TODO(bnell): put some of these in subclasses
    alpha_or_gscale: torch.Tensor | None = None

    # Zero points for int4/int8 types.
    zp: torch.Tensor | None = None

    # Biases for GPT OSS Triton MoE.
    bias: torch.Tensor | None = None
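
# Illustration (editor's sketch): a DeepSeek-style fp8 activation descriptor
# would look roughly like
#   FusedMoEQuantDesc(dtype=torch.float8_e4m3fn,
#                     shape=GroupShape(row=1, col=128),
#                     scale=act_scales)
# where `act_scales` is a hypothetical per-token-per-group scale tensor.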


# TODO(bnell): have subclasses for specific moe methods?
# e.g. for specific arguments bias, precision, etc.
@dataclass
class FusedMoEQuantConfig:
    """
    The FusedMoEQuantConfig contains all the quantization parameters for
    a single FusedMoEMethodBase operation. It consists of four
    FusedMoEQuantDescs, one for each activation and set of weights.

    Each FusedMoEMethodBase must implement a get_fused_moe_quant_config
    method to construct a FusedMoEQuantConfig for use with that class.

    FusedMoEQuantConfigs are only used for modular kernels, fused_experts
    (from fused_moe.py), cutlass_moe_fp[48], rocm_aiter_fused_experts and
    triton_kernel_moe_forward. Other MoE methods can ignore the
    FusedMoEQuantConfig (for now) and hardcode it to None.

    There are currently some restrictions on what can be expressed:
    - Most MoE ops only support similar quantization strategies for
      each parameter, e.g. both weights must have the same GroupShape
      and both activations must share the same GroupShape. One exception
      to this is the cutlass moe, which allows per-channel quantization on
      the outputs. Note: these restrictions are not always rigorously checked.
    - Not all fused MoE functions support all the parameters, e.g. zero
      points, global scales, alphas and biases are not universally supported.
    - Fully general GroupShapes are not allowed. Activations only support
      per-token, per-tensor or K-blocked quantization.
    - Weights are not required to have a GroupShape since they have already
      been quantized.

    Other notes:
    - PrecisionConfigs are specific to GPT OSS Triton.
    - As a follow up it would probably make sense to subclass FusedMoEQuantDesc
      or FusedMoEQuantConfig for particular FusedMoEMethodBase subclasses
      so that only the required quantization parameters are used/stored.
    """

    # TODO(bnell): make sure a1_scales/a2_scales don't interfere with chunking
    _a1: FusedMoEQuantDesc
    _a2: FusedMoEQuantDesc
    _w1: FusedMoEQuantDesc
    _w2: FusedMoEQuantDesc

    def __post_init__(self):
        assert not self.per_act_token_quant or self.block_shape is None, (
            "illegal quantization"
        )

    #
    # Convenience accessors for various properties.
    #

    @property
    def quant_dtype(self) -> torch.dtype | str | None:
        return self._a1.dtype

    @property
    def is_quantized(self) -> bool:
        return self.quant_dtype is not None

    @property
    def is_per_act_token(self) -> bool:
        return self._a1.shape == GroupShape.PER_TOKEN

    @property
    def per_act_token_quant(self) -> bool:
        return self._a1.shape == GroupShape.PER_TOKEN

    @property
    def per_out_ch_quant(self) -> bool:
        return self._w1.shape == GroupShape.PER_TOKEN

    @property
    def is_per_tensor(self) -> bool:
        return self._a1.shape == GroupShape.PER_TENSOR

    @property
    def block_shape(self) -> list[int] | None:
        if (
            self._a1.shape is not None
            and self._a1.shape != GroupShape.PER_TENSOR
            and self._a1.shape != GroupShape.PER_TOKEN
        ):
            return [self._a1.shape.row, self._a1.shape.col]
        else:
            return None

    @property
    def is_block_quantized(self) -> bool:
        return self.block_shape is not None

    @property
    def a1_scale(self) -> torch.Tensor | None:
        assert self._a1.scale is None or isinstance(self._a1.scale, torch.Tensor)
        return self._a1.scale

    @property
    def a1_gscale(self) -> torch.Tensor | None:
        return self._a1.alpha_or_gscale

    @property
    def a2_scale(self) -> torch.Tensor | None:
        assert self._a2.scale is None or isinstance(self._a2.scale, torch.Tensor)
        return self._a2.scale

    @property
    def a2_gscale(self) -> torch.Tensor | None:
        return self._a2.alpha_or_gscale

    @property
    def w1_scale(self) -> torch.Tensor | None:
        assert self._w1.scale is None or isinstance(self._w1.scale, torch.Tensor)
        return self._w1.scale

    @property
    def w1_zp(self) -> torch.Tensor | None:
        return self._w1.zp

    @property
    def w1_bias(self) -> torch.Tensor | None:
        return self._w1.bias

    @property
    def w1_precision(self) -> Optional["PrecisionConfig"]:
        assert self._w1.scale is None or isinstance(self._w1.scale, PrecisionConfig)
        return self._w1.scale

    @property
    def g1_alphas(self) -> torch.Tensor | None:
        return self._w1.alpha_or_gscale

    @property
    def w2_scale(self) -> torch.Tensor | None:
        assert self._w2.scale is None or isinstance(self._w2.scale, torch.Tensor)
        return self._w2.scale

    @property
    def w2_zp(self) -> torch.Tensor | None:
        return self._w2.zp

    @property
    def w2_bias(self) -> torch.Tensor | None:
        return self._w2.bias

    @property
    def w2_precision(self) -> Optional["PrecisionConfig"]:
        assert self._w2.scale is None or isinstance(self._w2.scale, PrecisionConfig)
        return self._w2.scale

    @property
    def g2_alphas(self) -> torch.Tensor | None:
        return self._w2.alpha_or_gscale

    @property
    def use_fp8_w8a8(self) -> bool:
        return self.quant_dtype == torch.float8_e4m3fn

    @property
    def use_int8_w8a8(self) -> bool:
        return self.quant_dtype == torch.int8

    @property
    def use_int8_w8a16(self) -> bool:
        return self._a1.dtype is None and self._w1.dtype == torch.int8

    @property
    def use_int4_w4a16(self) -> bool:
        return self._a1.dtype is None and self._w1.dtype == "int4"

    @property
    def ocp_mx_scheme(self) -> str | None:
        if not hasattr(self, "_ocp_mx_scheme"):
            if (self._a1.dtype is not None and not isinstance(self._a1.dtype, str)) or (
                self._w1.dtype is not None and not isinstance(self._w1.dtype, str)
            ):
                self._ocp_mx_scheme = None
            else:
                ocp_mx_scheme = OCP_MX_Scheme.from_quant_dtype(
                    self._a1.dtype, self._w1.dtype
                )

                if ocp_mx_scheme is not None:
                    ocp_mx_scheme = ocp_mx_scheme.value

                self._ocp_mx_scheme = ocp_mx_scheme

        return self._ocp_mx_scheme
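
    # Illustration (editor's sketch): this property caches its result, so a
    # config whose activation and weight dtypes are both OCP MX strings (e.g.
    # "mxfp4") resolves OCP_MX_Scheme.from_quant_dtype("mxfp4", "mxfp4") once
    # and then reuses the cached scheme string; torch dtypes always yield None.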

    @property
    def use_mxfp4_w4a16(self) -> bool:
        return self._a1.dtype is None and self._w1.dtype == "mxfp4"

    @property
    def use_nvfp4_w4a4(self) -> bool:
        return self.quant_dtype == "nvfp4"

    def config_name(self, dtype: torch.dtype) -> str | None:
        """
        Return a string used to construct the filename that contains the
        tuning info for a particular quantization scheme. See
        try_get_optimal_moe_config in fused_moe.py.
        """
        return _get_config_dtype_str(
            use_fp8_w8a8=self.use_fp8_w8a8,
            use_int8_w8a16=self.use_int8_w8a16,
            use_int4_w4a16=self.use_int4_w4a16,
            ocp_mx_scheme=self.ocp_mx_scheme,
            dtype=dtype,
        )

    def scale_shape(
        self,
        max_tokens: int,
        hidden_dim: int,
    ) -> tuple[int, int] | None:
        """
        Construct the proper activation scale shape for this
        config.
        """
        if self.is_quantized:
            if self.is_block_quantized:
                assert self.block_shape is not None
                _, block_k = self.block_shape
                k_tiles = cdiv(hidden_dim, block_k)
                return (max_tokens, k_tiles)
            elif self.is_per_act_token:
                return (max_tokens, 1)
            else:
                return (1, 1)
        else:
            return None
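
    # Illustration (editor's sketch): with block_shape=[128, 128],
    # hidden_dim=4096 and max_tokens=256 this returns
    # (256, cdiv(4096, 128)) == (256, 32); per-act-token quantization gives
    # (256, 1) and per-tensor quantization gives (1, 1).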

    def batched_scale_shape(
        self,
        num_experts: int,
        max_tokens: int,
        hidden_dim: int,
    ) -> tuple[int, int, int] | None:
        """
        Construct the proper batched activation scale shape for this
        config, i.e. (num experts, *scale_shape).
        """
        if self.is_quantized:
            scale_shape = self.scale_shape(max_tokens, hidden_dim)
            assert scale_shape is not None
            return (num_experts, *scale_shape)
        else:
            return None
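
    # Illustration (editor's sketch): continuing the example above with
    # num_experts=8, the batched shape becomes (8, 256, 32).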

    @staticmethod
    def make(
        quant_dtype: torch.dtype | str | None = None,
        per_act_token_quant: bool = False,
        per_out_ch_quant: bool = False,
        block_shape: list[int] | None = None,
        w1_scale: Union[torch.Tensor, "PrecisionConfig", None] = None,
        w2_scale: Union[torch.Tensor, "PrecisionConfig", None] = None,
        a1_scale: torch.Tensor | None = None,
        a2_scale: torch.Tensor | None = None,
        g1_alphas: torch.Tensor | None = None,
        g2_alphas: torch.Tensor | None = None,
        a1_gscale: torch.Tensor | None = None,
        a2_gscale: torch.Tensor | None = None,
        w1_bias: torch.Tensor | None = None,
        w2_bias: torch.Tensor | None = None,
        w1_zp: torch.Tensor | None = None,
        w2_zp: torch.Tensor | None = None,
        weight_dtype: torch.dtype | str | None = None,
    ) -> "FusedMoEQuantConfig":
        """
        General builder function for a FusedMoEQuantConfig.
        - quant_dtype: Optional quantization type. None if activations are
          unquantized or quantized prior to calling. Note: "nvfp4", "mxfp4",
          "mxfp6_e3m2" and "mxfp6_e2m3" are the only valid string values
          for quant_dtype.
        - per_act_token_quant: Activations have per-token quantization.
        - per_out_ch_quant: Outputs have per-channel quantization (only
          for cutlass).
        - block_shape: Optional block size for block-wise quantization.
          Incompatible with per_act_token and per_out_ch quant.
        - w1_scale: Optional scale to be used for w1.
        - w2_scale: Optional scale to be used for w2.
        - a1_scale: Optional scale to be used for a1.
        - a2_scale: Optional scale to be used for a2.
        - g1_alphas: Optional global quantization scales for w1 (for nvfp4).
        - g2_alphas: Optional global quantization scales for w2 (for nvfp4).
        - a1_gscale: Optional global quantization scales for a1 (for nvfp4).
        - a2_gscale: Optional global quantization scales for a2 (for nvfp4).
        - w1_bias: Optional biases for w1 (GPT OSS Triton).
        - w2_bias: Optional biases for w2 (GPT OSS Triton).
        - w1_zp: Optional w1 zero points for int4/int8 quantization.
        - w2_zp: Optional w2 zero points for int4/int8 quantization.
        """
        assert not isinstance(quant_dtype, str) or quant_dtype in {
            "nvfp4",
            "mxfp4",
            "mxfp6_e3m2",
            "mxfp6_e2m3",
        }
        assert not isinstance(weight_dtype, str) or weight_dtype in {
            "nvfp4",
            "mxfp4",
            "mxfp6_e3m2",
            "mxfp6_e2m3",
        }

        if weight_dtype is None:
            weight_dtype = quant_dtype

        a_shape, w_shape = _quant_flags_to_group_shape(
            quant_dtype, per_act_token_quant, per_out_ch_quant, block_shape
        )
        quant_config = FusedMoEQuantConfig(
            _a1=FusedMoEQuantDesc(quant_dtype, a_shape, a1_scale, a1_gscale),
            _a2=FusedMoEQuantDesc(quant_dtype, a_shape, a2_scale, a2_gscale),
            _w1=FusedMoEQuantDesc(
                weight_dtype, w_shape, w1_scale, g1_alphas, w1_zp, w1_bias
            ),
            _w2=FusedMoEQuantDesc(
                weight_dtype, w_shape, w2_scale, g2_alphas, w2_zp, w2_bias
            ),
        )
        assert quant_config.per_act_token_quant == per_act_token_quant
        assert quant_config.per_out_ch_quant == per_out_ch_quant
        assert quant_config.block_shape == block_shape
        return quant_config
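
# Illustration (editor's sketch): a per-tensor fp8 config could be built with
#   FusedMoEQuantConfig.make(torch.float8_e4m3fn,
#                            w1_scale=w1_s, w2_scale=w2_s,
#                            a1_scale=a1_s, a2_scale=a2_s)
# where w1_s/w2_s/a1_s/a2_s are hypothetical scale tensors; the helper
# functions below wrap this pattern for each supported scheme.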


def fp8_w8a8_moe_quant_config(
    w1_scale: torch.Tensor,
    w2_scale: torch.Tensor,
    a1_scale: torch.Tensor | None = None,
    a2_scale: torch.Tensor | None = None,
    per_act_token_quant: bool = False,
    per_out_ch_quant: bool = False,
    block_shape: list[int] | None = None,
    a1_gscale: torch.Tensor | None = None,
    a2_gscale: torch.Tensor | None = None,
    g1_alphas: torch.Tensor | None = None,
    g2_alphas: torch.Tensor | None = None,
) -> FusedMoEQuantConfig:
    """
    Construct a quant config for fp8 activations and fp8 weights.
    """
    return FusedMoEQuantConfig.make(
        torch.float8_e4m3fn,
        w1_scale=w1_scale,
        g1_alphas=g1_alphas,
        w2_scale=w2_scale,
        g2_alphas=g2_alphas,
        a1_scale=a1_scale,
        a1_gscale=a1_gscale,
        a2_scale=a2_scale,
        a2_gscale=a2_gscale,
        per_act_token_quant=per_act_token_quant,
        per_out_ch_quant=per_out_ch_quant,
        block_shape=block_shape,
    )


def int8_w8a8_moe_quant_config(
    w1_scale: torch.Tensor,
    w2_scale: torch.Tensor,
    a1_scale: torch.Tensor | None,
    a2_scale: torch.Tensor | None,
    per_act_token_quant: bool = False,
) -> FusedMoEQuantConfig:
    """
    Construct a quant config for int8 activations and int8 weights.
    """
    return FusedMoEQuantConfig.make(
        torch.int8,
        w1_scale=w1_scale,
        w2_scale=w2_scale,
        a1_scale=a1_scale,
        a2_scale=a2_scale,
        per_act_token_quant=per_act_token_quant,
        per_out_ch_quant=False,
        block_shape=None,
    )


def mxfp4_w4a16_moe_quant_config(
    w1_scale: Union[torch.Tensor, "PrecisionConfig"],
    w2_scale: Union[torch.Tensor, "PrecisionConfig"],
    w1_bias: torch.Tensor | None = None,
    w2_bias: torch.Tensor | None = None,
) -> FusedMoEQuantConfig:
    """
    Construct a quant config for unquantized activations and mxfp4 weights.
    """
    return FusedMoEQuantConfig(
        _a1=FusedMoEQuantDesc(),
        _a2=FusedMoEQuantDesc(),
        _w1=FusedMoEQuantDesc("mxfp4", None, w1_scale, None, None, w1_bias),
        _w2=FusedMoEQuantDesc("mxfp4", None, w2_scale, None, None, w2_bias),
    )


def mxfp4_mxfp8_moe_quant_config(
    w1_scale: Union[torch.Tensor, "PrecisionConfig"],
    w2_scale: Union[torch.Tensor, "PrecisionConfig"],
    a1_scale: torch.Tensor | None = None,
    a2_scale: torch.Tensor | None = None,
    w1_bias: torch.Tensor | None = None,
    w2_bias: torch.Tensor | None = None,
    block_shape: list[int] | None = None,
) -> FusedMoEQuantConfig:
    """
    Construct a quant config for mxfp8 activations and mxfp4 weights.
    """
    return FusedMoEQuantConfig(
        _a1=FusedMoEQuantDesc("mxfp8"),
        _a2=FusedMoEQuantDesc("mxfp8"),
        _w1=FusedMoEQuantDesc("mxfp4", None, w1_scale, None, None, w1_bias),
        _w2=FusedMoEQuantDesc("mxfp4", None, w2_scale, None, None, w2_bias),
    )


def ocp_mx_moe_quant_config(
    quant_dtype: str,
    w1_scale: Union[torch.Tensor, "PrecisionConfig"],
    w2_scale: Union[torch.Tensor, "PrecisionConfig"],
    weight_dtype: str | None = None,
    a1_scale: torch.Tensor | None = None,
    a2_scale: torch.Tensor | None = None,
    w1_bias: torch.Tensor | None = None,
    w2_bias: torch.Tensor | None = None,
    block_shape: list[int] | None = None,
) -> FusedMoEQuantConfig:
    """
    Construct a quant config for OCP MX activations and weights
    (e.g. mxfp4/mxfp6).
    """
    assert quant_dtype in OCP_MX_DTYPES
    return FusedMoEQuantConfig.make(
        quant_dtype=quant_dtype,
        weight_dtype=weight_dtype,
        w1_scale=w1_scale,
        w2_scale=w2_scale,
        a1_scale=a1_scale,
        a2_scale=a2_scale,
        w1_bias=w1_bias,
        w2_bias=w2_bias,
        per_act_token_quant=False,
        per_out_ch_quant=False,
        block_shape=block_shape,
    )


def nvfp4_moe_quant_config(
    g1_alphas: torch.Tensor,
    g2_alphas: torch.Tensor,
    a1_gscale: torch.Tensor,
    a2_gscale: torch.Tensor,
    w1_scale: torch.Tensor,
    w2_scale: torch.Tensor,
) -> FusedMoEQuantConfig:
    """
    Construct a quant config for nvfp4 activations and nvfp4 weights.
    """
    return FusedMoEQuantConfig.make(
        "nvfp4",
        w1_scale=w1_scale,
        w2_scale=w2_scale,
        a1_gscale=a1_gscale,
        a2_gscale=a2_gscale,
        g1_alphas=g1_alphas,
        g2_alphas=g2_alphas,
        per_act_token_quant=False,
        per_out_ch_quant=False,
        block_shape=None,
    )


def int4_w4a16_moe_quant_config(
    w1_scale: torch.Tensor,
    w2_scale: torch.Tensor,
    w1_zp: torch.Tensor | None,
    w2_zp: torch.Tensor | None,
    block_shape: list[int] | None = None,
) -> FusedMoEQuantConfig:
    """
    Construct a quant config for 16-bit float activations and int4 weights.
    Note: Activations are pre-quantized.
    """
    group_shape = GroupShape(*block_shape) if block_shape is not None else None
    return FusedMoEQuantConfig(
        _a1=FusedMoEQuantDesc(shape=group_shape),
        _a2=FusedMoEQuantDesc(shape=group_shape),
        _w1=FusedMoEQuantDesc("int4", group_shape, w1_scale, None, w1_zp),
        _w2=FusedMoEQuantDesc("int4", group_shape, w2_scale, None, w2_zp),
    )


def int8_w8a16_moe_quant_config(
    w1_scale: torch.Tensor,
    w2_scale: torch.Tensor,
    w1_zp: torch.Tensor | None,
    w2_zp: torch.Tensor | None,
    block_shape: list[int] | None = None,
) -> FusedMoEQuantConfig:
    """
    Construct a quant config for 16-bit float activations and int8 weights.
    Note: Activations are pre-quantized.
    """
    group_shape = GroupShape(*block_shape) if block_shape is not None else None
    return FusedMoEQuantConfig(
        _a1=FusedMoEQuantDesc(shape=group_shape),
        _a2=FusedMoEQuantDesc(shape=group_shape),
        _w1=FusedMoEQuantDesc(torch.int8, group_shape, w1_scale, None, w1_zp),
        _w2=FusedMoEQuantDesc(torch.int8, group_shape, w2_scale, None, w2_zp),
    )


def biased_moe_quant_config(
    w1_bias: torch.Tensor | None,
    w2_bias: torch.Tensor | None,
) -> FusedMoEQuantConfig:
    """
    Construct a quant config for unquantized activations with biases.
    """
    return FusedMoEQuantConfig(
        _a1=FusedMoEQuantDesc(),
        _a2=FusedMoEQuantDesc(),
        _w1=FusedMoEQuantDesc(bias=w1_bias),
        _w2=FusedMoEQuantDesc(bias=w2_bias),
    )


# A FusedMoEQuantConfig constant for an unquantized MoE op.
FUSED_MOE_UNQUANTIZED_CONFIG: FusedMoEQuantConfig = FusedMoEQuantConfig.make()


@dataclass
class FusedMoEParallelConfig:
    tp_size: int
    dp_size: int
    ep_size: int
    tp_rank: int
    dp_rank: int
    ep_rank: int

    use_ep: bool  # whether to use EP or not
    all2all_backend: str  # all2all backend for MoE communication

    @property
    def use_all2all_kernels(self):
        return self.dp_size > 1 and self.use_ep

    @property
    def use_pplx_kernels(self):
        return self.use_all2all_kernels and self.all2all_backend == "pplx"

    @property
    def use_deepep_ht_kernels(self):
        return (
            self.use_all2all_kernels
            and self.all2all_backend == "deepep_high_throughput"
        )

    @property
    def use_deepep_ll_kernels(self):
        return self.use_all2all_kernels and self.all2all_backend == "deepep_low_latency"

    @staticmethod
    def flatten_tp_across_dp(
        tp_size: int, dp_size: int, dp_rank: int
    ) -> tuple[int, int]:
        tp_rank = 0 if tp_size == 1 else get_tensor_model_parallel_rank()
        # There are actually dp_size * tp_size devices. Update tp_size
        # and tp_rank so we shard across all devices.
        flatten_tp_size = dp_size * tp_size
        flatten_tp_rank = dp_rank * tp_size + tp_rank
        return flatten_tp_size, flatten_tp_rank
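
    # Illustration (editor's sketch): with tp_size=2, dp_size=2 there are 4
    # shards in the flattened TP group; on dp_rank=1 with local tp_rank=0 this
    # returns (4, 2), matching the TP = {4, 2} entry in the examples in
    # make() below.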

    @staticmethod
    def make(
        tp_size_: int, dp_size_: int, vllm_parallel_config: ParallelConfig
    ) -> "FusedMoEParallelConfig":
        """
        Determine MoE parallel configuration. Based on the input `tp_size_`,
        `dp_size_` and vllm's parallel config, determine what
        levels of parallelism to use in the fused moe layer.

        Args:
            tp_size_ (int): `tp_size` passed into the FusedMoE constructor.
            dp_size_ (int): `dp_size` passed into the FusedMoE constructor.
            vllm_parallel_config (ParallelConfig): vLLM's parallel config
                object which contains the `enable_expert_parallel` flag.

        Examples:
            When there is no parallelism requested,
            i.e. `tp_size_` = `dp_size_` = 1, we simply return the sizes
            unaltered and the ranks set to 0.

            Expert Parallelism is considered only when either `dp_size_` or
            `tp_size_` is non-trivial.

            When TP = 2, DP = 1 and EP = False, the configuration on different
            devices:

            - device 0 : TP = {2, 0} DP = {1, 0} EP = {1, 0} //
              legend : {size, rank}
            - device 1 : TP = {2, 1} DP = {1, 0} EP = {1, 0}
            - Comment : Tensors are sharded across 2 devices.

            When TP = 1, DP = 2 and EP = False, the configuration on different
            devices:

            - device 0 : TP = {2, 0} DP = {2, 0} EP = {1, 0}
            - device 1 : TP = {2, 1} DP = {2, 1} EP = {1, 0}
            - Comment: There are 2 engine instances and the tensors are sharded
              across 2 devices.

            When TP = 2, DP = 2 and EP = False, the configuration on different
            devices:

            - device 0: TP = {4, 0} DP = {2, 0} EP = {1, 0}
            - device 1: TP = {4, 1} DP = {2, 0} EP = {1, 0}
            - device 2: TP = {4, 2} DP = {2, 1} EP = {1, 0}
            - device 3: TP = {4, 3} DP = {2, 1} EP = {1, 0}
            - Comment: There are 2 engine instances and the tensors are sharded
              across 4 devices.

            When TP = 2, DP = 1 and EP = True, the configuration on different
            devices:

            - device 0: TP = {1, 0} DP = {1, 0} EP = {2, 0}
            - device 1: TP = {1, 0} DP = {1, 0} EP = {2, 1}
            - Comment: The experts are split between the 2 devices.

            When TP = 1, DP = 2 and EP = True, the configuration on different
            devices:

            - device 0: TP = {1, 0} DP = {2, 0} EP = {2, 0}
            - device 1: TP = {1, 0} DP = {2, 1} EP = {2, 1}
            - Comment: There are 2 engine instances and the experts are split
              between the 2 devices.

            When TP = 2, DP = 2 and EP = True, the configuration on different
            devices:

            - device 0: TP = {1, 0} DP = {2, 0} EP = {4, 0}
            - device 1: TP = {1, 0} DP = {2, 0} EP = {4, 1}
            - device 2: TP = {1, 0} DP = {2, 1} EP = {4, 2}
            - device 3: TP = {1, 0} DP = {2, 1} EP = {4, 3}
            - Comment: There are 2 engine instances and the experts are split
              between the 4 devices.
        """

        use_ep = dp_size_ * tp_size_ > 1 and vllm_parallel_config.enable_expert_parallel

        dp_size = dp_size_
        dp_rank = get_dp_group().rank_in_group if dp_size > 1 else 0
        tp_size, tp_rank = FusedMoEParallelConfig.flatten_tp_across_dp(
            tp_size_, dp_size_, dp_rank
        )

        if not use_ep:
            return FusedMoEParallelConfig(
                tp_size=tp_size,
                tp_rank=tp_rank,
                dp_size=dp_size,
                dp_rank=dp_rank,
                ep_size=1,
                ep_rank=0,
                use_ep=False,
                all2all_backend=vllm_parallel_config.all2all_backend,
            )
        # DP + EP / TP + EP / DP + TP + EP
        assert use_ep
        # In EP, each device fully owns a set of experts and there is no
        # tensor parallelism, so update tp_size, tp_rank, ep_size and ep_rank
        # to reflect that.
        ep_size = tp_size
        ep_rank = tp_rank
        return FusedMoEParallelConfig(
            tp_size=1,
            tp_rank=0,
            dp_size=dp_size,
            dp_rank=dp_rank,
            ep_size=ep_size,
            ep_rank=ep_rank,
            use_ep=True,
            all2all_backend=vllm_parallel_config.all2all_backend,
        )
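
# Illustration (editor's sketch): a typical call would be
#   moe_parallel_config = FusedMoEParallelConfig.make(
#       tp_size_=2, dp_size_=2, vllm_parallel_config=parallel_config)
# With parallel_config.enable_expert_parallel set, this yields tp_size=1 and
# ep_size=4, as in the TP = 2, DP = 2, EP = True example above.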


# Adapted from pplx-kernels tests/all_to_all_utils.py
@dataclass
class FusedMoEConfig:
    num_experts: int
    experts_per_token: int
    hidden_dim: int

    num_local_experts: int
    moe_parallel_config: FusedMoEParallelConfig

    # The activation type.
    in_dtype: torch.dtype

    max_num_tokens: int = envs.VLLM_MOE_DP_CHUNK_SIZE

    has_bias: bool = False

    is_act_and_mul: bool = True

    is_lora_enabled: bool = False

    def __post_init__(self):
        if self.dp_size > 1:
            logger.debug_once(
                "Using FusedMoEConfig::max_num_tokens=%d", self.max_num_tokens
            )

        assert self.max_num_tokens > 0

    @property
    def tp_size(self):
        return self.moe_parallel_config.tp_size

    @property
    def dp_size(self):
        return self.moe_parallel_config.dp_size

    @property
    def ep_size(self):
        return self.moe_parallel_config.ep_size

    @property
    def tp_rank(self):
        return self.moe_parallel_config.tp_rank

    @property
    def dp_rank(self):
        return self.moe_parallel_config.dp_rank

    @property
    def ep_rank(self):
        return self.moe_parallel_config.ep_rank

    @property
    def use_ep(self):
        return self.moe_parallel_config.use_ep

    @property
    def use_pplx_kernels(self):
        return self.moe_parallel_config.use_pplx_kernels

    @property
    def use_deepep_ht_kernels(self):
        return self.moe_parallel_config.use_deepep_ht_kernels

    @property
    def use_deepep_ll_kernels(self):
        return self.moe_parallel_config.use_deepep_ll_kernels

    @property
    def use_flashinfer_cutlass_kernels(self):
        """
        Whether to use FlashInfer cutlass kernels for NVFP4 MoE.
        """
        return (
            envs.VLLM_USE_FLASHINFER_MOE_FP4
            and has_flashinfer_cutlass_fused_moe()
            and envs.VLLM_FLASHINFER_MOE_BACKEND == "throughput"
        )
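
# Illustration (editor's sketch): a hypothetical construction for an
# 8-expert, top-2 MoE layer:
#   FusedMoEConfig(num_experts=8, experts_per_token=2, hidden_dim=4096,
#                  num_local_experts=8 // moe_parallel_config.ep_size,
#                  moe_parallel_config=moe_parallel_config,
#                  in_dtype=torch.bfloat16)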