[gpt-oss] Add gpt-oss mxfp4 support
This commit is contained in:
490
vllm/model_executor/layers/fused_moe/config.py
Normal file
490
vllm/model_executor/layers/fused_moe/config.py
Normal file
@@ -0,0 +1,490 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
from compressed_tensors.quantization import (QuantizationArgs,
|
||||
QuantizationStrategy,
|
||||
QuantizationType)
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.config import ParallelConfig
|
||||
from vllm.distributed import get_dp_group, get_tensor_model_parallel_rank
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig)
|
||||
from vllm.utils import cdiv
|
||||
# from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def _get_quant_config_quantization_args(
    quant_config: Optional[QuantizationConfig],
    prop_name: str,
) -> Optional[QuantizationArgs]:
    """Fetch `prop_name` from the quant config's "Linear" target scheme.

    Returns None unless the config exposes a `target_scheme_map` with a
    "Linear" entry that carries "input_activations" info.
    """
    if quant_config is None or not hasattr(quant_config, 'target_scheme_map'):
        return None
    scheme_map = quant_config.target_scheme_map
    if "Linear" not in scheme_map:
        return None
    linear_scheme = scheme_map["Linear"]
    if "input_activations" not in linear_scheme:
        return None
    return linear_scheme.get(prop_name)
|
||||
|
||||
|
||||
def get_quant_config_input_quant(
    quant_config: Optional[QuantizationConfig]
) -> Optional[QuantizationArgs]:
    """Return the input-activation QuantizationArgs, if configured."""
    return _get_quant_config_quantization_args(quant_config,
                                               "input_activations")
|
||||
|
||||
|
||||
def get_quant_config_weight_quant(
    quant_config: Optional[QuantizationConfig]
) -> Optional[QuantizationArgs]:
    """Return the weight QuantizationArgs, if configured."""
    return _get_quant_config_quantization_args(quant_config, "weights")
|
||||
|
||||
|
||||
# TODO (bnell): use scalar_type instead of bools?
|
||||
def get_config_quant_dtype(
|
||||
use_fp8_w8a8: bool,
|
||||
use_int8_w8a8: bool,
|
||||
use_int8_w8a16: bool,
|
||||
use_int4_w4a16: bool,
|
||||
use_mxfp4_w4a4: bool,
|
||||
) -> Union[None, torch.dtype, str]:
|
||||
if use_fp8_w8a8:
|
||||
return torch.float8_e4m3fn
|
||||
elif use_int8_w8a8:
|
||||
return torch.int8
|
||||
elif use_mxfp4_w4a4:
|
||||
return "mxfp4"
|
||||
return None
|
||||
|
||||
|
||||
@dataclass
|
||||
class FusedMoEQuantConfig:
|
||||
# The post quantization activation type.
|
||||
quant_dtype: Optional[torch.dtype] = None
|
||||
per_act_token_quant: bool = False
|
||||
per_out_ch_quant: bool = False
|
||||
block_shape: Optional[list[int]] = None
|
||||
|
||||
# TODO: add col major flag?
|
||||
# add detailed quant info for input, intermediates, weights, etc?
|
||||
|
||||
def __post_init__(self):
|
||||
assert (not self.per_act_token_quant
|
||||
or self.block_shape is None), "illegal quantization"
|
||||
|
||||
@property
|
||||
def is_quantized(self) -> bool:
|
||||
return self.quant_dtype is not None
|
||||
|
||||
@property
|
||||
def is_per_act_token(self) -> bool:
|
||||
return self.per_act_token_quant
|
||||
|
||||
@property
|
||||
def is_block_quantized(self) -> bool:
|
||||
return self.block_shape is not None
|
||||
|
||||
@property
|
||||
def is_per_tensor(self) -> bool:
|
||||
return not self.per_act_token_quant and self.block_shape is None
|
||||
|
||||
def scale_shape(
|
||||
self,
|
||||
max_tokens: int,
|
||||
hidden_dim: int,
|
||||
) -> Optional[tuple[int, int]]:
|
||||
if self.is_quantized:
|
||||
if self.is_block_quantized:
|
||||
assert self.block_shape is not None
|
||||
_, block_k = self.block_shape
|
||||
k_tiles = cdiv(hidden_dim, block_k)
|
||||
return (max_tokens, k_tiles)
|
||||
elif self.is_per_act_token:
|
||||
return (max_tokens, 1)
|
||||
else:
|
||||
return (1, 1)
|
||||
else:
|
||||
return None
|
||||
|
||||
def batched_scale_shape(
|
||||
self,
|
||||
num_experts: int,
|
||||
max_tokens: int,
|
||||
hidden_dim: int,
|
||||
) -> Optional[tuple[int, int, int]]:
|
||||
if self.is_quantized:
|
||||
scale_shape = self.scale_shape(max_tokens, hidden_dim)
|
||||
assert scale_shape is not None
|
||||
return (num_experts, *scale_shape)
|
||||
else:
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def make(
|
||||
use_fp8_w8a8: bool = False,
|
||||
use_int8_w8a8: bool = False,
|
||||
use_int8_w8a16: bool = False,
|
||||
use_int4_w4a16: bool = False,
|
||||
use_mxfp4_w4a4: bool = False,
|
||||
per_act_token_quant: bool = False,
|
||||
per_out_ch_quant: bool = False,
|
||||
block_shape: Optional[list[int]] = None,
|
||||
) -> "FusedMoEQuantConfig":
|
||||
assert sum([
|
||||
int(flag) for flag in [
|
||||
use_fp8_w8a8,
|
||||
use_int8_w8a8,
|
||||
use_int8_w8a16,
|
||||
use_int4_w4a16,
|
||||
]
|
||||
]) <= 1, "Quantization flags are mutually exclusive."
|
||||
|
||||
quant_dtype = get_config_quant_dtype(
|
||||
use_fp8_w8a8=use_fp8_w8a8,
|
||||
use_int8_w8a8=use_int8_w8a8,
|
||||
use_int8_w8a16=use_int8_w8a16,
|
||||
use_int4_w4a16=use_int4_w4a16,
|
||||
use_mxfp4_w4a4=use_mxfp4_w4a4,
|
||||
)
|
||||
return FusedMoEQuantConfig(
|
||||
quant_dtype,
|
||||
per_act_token_quant,
|
||||
per_out_ch_quant,
|
||||
block_shape,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
class FusedMoEParallelConfig:
    """Parallelism layout (TP/DP/EP sizes and ranks) for a fused MoE layer."""

    tp_size: int
    dp_size: int
    ep_size: int
    tp_rank: int
    dp_rank: int
    ep_rank: int

    use_ep: bool  # whether to use EP or not

    @property
    def use_all2all_kernels(self):
        # All-to-all dispatch/combine kernels are only relevant when tokens
        # are spread over data-parallel ranks under expert parallelism.
        return self.dp_size > 1 and self.use_ep

    @property
    def use_pplx_kernels(self):
        return (self.use_all2all_kernels
                and envs.VLLM_ALL2ALL_BACKEND == "pplx")

    @property
    def use_deepep_ht_kernels(self):
        return (self.use_all2all_kernels
                and envs.VLLM_ALL2ALL_BACKEND == "deepep_high_throughput")

    @property
    def use_deepep_ll_kernels(self):
        return (self.use_all2all_kernels
                and envs.VLLM_ALL2ALL_BACKEND == "deepep_low_latency")

    @property
    def use_flashinfer_cutlass_kernels(self):
        # Temporarily hard-disabled; the intended check is kept below for
        # when flashinfer cutlass fused-MoE support is re-enabled.
        # return (envs.VLLM_USE_FLASHINFER_MOE_FP4
        #         and has_flashinfer_cutlass_fused_moe()
        #         and envs.VLLM_FLASHINFER_MOE_BACKEND == "throughput")
        return False

    @staticmethod
    def make(tp_size_: int, dp_size_: int,
             vllm_parallel_config: ParallelConfig) -> "FusedMoEParallelConfig":
        """
        Determine MoE parallel configuration. Based on the input `tp_size_`,
        `dp_size_` and vllm's parallel config, determine what
        levels of parallelism to use in the fused moe layer.

        Args:
            tp_size_ (int): `tp_size` passed into the FusedMoE constructor.
            dp_size_ (int): `dp_size` passed into the FusedMoE constructor.
            vllm_parallel_config (ParallelConfig): vLLM's parallel config
                object which contains the `enable_expert_parallel` flag.

        Examples:
            When there is no parallelism requested,
            i.e. `tp_size_` = `dp_size_` = 1, we simply return the sizes
            unaltered and the ranks set to 0.

            Expert Parallelism is considered only when either `dp_size_` or
            `tp_size_` is non trivial.

            When TP = 2, DP = 1 and EP = False, the configuration on different
            devices:

                - device 0 : TP = {2, 0} DP = {1, 0} EP = {1, 0} //
                    legend : {size, rank}
                - device 1 : TP = {2, 1} DP = {1, 0} EP = {1, 0}
                - Comment : Tensors are sharded across 2 devices.

            When TP = 1, DP = 2 and EP = False, the configuration on different
            devices:

                - device 0 : TP = {2, 0} DP = {2, 0} EP = {1, 0}
                - device 1 : TP = {2, 1} DP = {2, 1} EP = {1, 0}
                - Comment: There are 2 engine instances and the tensors are sharded
                    across 2 devices.

            When TP = 2, DP = 2 and EP = False, the configuration on different
            devices:

                - device 0: TP = {4, 0} DP = {2, 0} EP = {1, 0}
                - device 1: TP = {4, 1} DP = {2, 0} EP = {1, 0}
                - device 2: TP = {4, 2} DP = {2, 1} EP = {1, 0}
                - device 3: TP = {4, 3} DP = {2, 1} EP = {1, 0}
                - Comment: There are 2 engine instances and the tensors are sharded
                    across 4 devices.

            When, TP = 2, DP = 1 and EP = True, the configuration on different
            devices:

                - device 0: TP = {1, 0} DP = {1, 0} EP = {2, 0}
                - device 1: TP = {1, 0} DP = {1, 0} EP = {2, 1}
                - Comment: The experts are split between the 2 devices.

            When, TP = 1, DP = 2 and EP = True, the configuration on different
            devices:

                - device 0: TP = {1, 0} DP = {2, 0} EP = {2, 0}
                - device 1: TP = {1, 0} DP = {2, 1} EP = {2, 1}
                - Comment: There are 2 engine instances and the experts are split
                    between the 2 devices.

            When TP = 2, DP = 2 and EP = True, the configuration on different
            devices:

                - device 0: TP = {1, 0} DP = {2, 0} EP = {4, 0}
                - device 1: TP = {1, 0} DP = {2, 0} EP = {4, 1}
                - device 2: TP = {1, 0} DP = {2, 1} EP = {4, 2}
                - device 3: TP = {1, 0} DP = {2, 1} EP = {4, 3}
                - Comment: There are 2 engine instances and the experts are split
                    between the 4 devices.
        """

        def flatten_tp_across_dp(dp_rank: int):
            # NOTE(review): assumes get_tensor_model_parallel_rank() is only
            # meaningful when tp_size_ > 1 — confirm against callers.
            tp_rank = 0 if tp_size_ == 1 else get_tensor_model_parallel_rank()
            # There are actually dp_size_ * tp_size_ devices. Update tp_size
            # and tp_rank so we shard across all devices.
            tp_size = dp_size_ * tp_size_
            tp_rank = dp_rank * tp_size_ + tp_rank
            return tp_size, tp_rank

        # EP kicks in only with >1 total devices and an explicit opt-in.
        use_ep = (dp_size_ * tp_size_ > 1
                  and vllm_parallel_config.enable_expert_parallel)

        dp_size = dp_size_
        dp_rank = get_dp_group().rank_in_group if dp_size > 1 else 0
        tp_size, tp_rank = flatten_tp_across_dp(dp_rank)

        if not use_ep:
            return FusedMoEParallelConfig(tp_size=tp_size,
                                          tp_rank=tp_rank,
                                          dp_size=dp_size,
                                          dp_rank=dp_rank,
                                          ep_size=1,
                                          ep_rank=0,
                                          use_ep=False)
        # DP + EP / TP + EP / DP + TP + EP
        assert use_ep
        # In EP, each device owns a set of experts fully. There is no tensor
        # parallel update tp_size, tp_rank, ep_size and ep_rank to reflect that.
        ep_size = tp_size
        ep_rank = tp_rank
        return FusedMoEParallelConfig(tp_size=1,
                                      tp_rank=0,
                                      dp_size=dp_size,
                                      dp_rank=dp_rank,
                                      ep_size=ep_size,
                                      ep_rank=ep_rank,
                                      use_ep=True)
|
||||
|
||||
|
||||
# Adapted from pplx-kernels tests/all_to_all_utils.py
@dataclass
class FusedMoEConfig:
    """Aggregate configuration for a fused MoE layer: expert counts,
    parallelism layout, activation dtype, and (optional) quantization."""

    num_experts: int
    experts_per_token: int
    hidden_dim: int

    num_local_experts: int
    moe_parallel_config: FusedMoEParallelConfig

    # The activation type.
    in_dtype: torch.dtype

    quant_config: Optional[FusedMoEQuantConfig] = None

    max_num_tokens: int = envs.VLLM_MOE_DP_CHUNK_SIZE

    has_bias: bool = False

    def __post_init__(self):
        if self.dp_size > 1:
            logger.debug_once("Using FusedMoEConfig::max_num_tokens=%d",
                              self.max_num_tokens)

        assert self.max_num_tokens > 0

    @property
    def quant_dtype(self) -> Optional[torch.dtype]:
        if self.quant_config is not None:
            return self.quant_config.quant_dtype
        return None

    @property
    def block_shape(self) -> Optional[list[int]]:
        if self.quant_config is not None:
            return self.quant_config.block_shape
        return None

    @property
    def per_act_token_quant(self) -> bool:
        if self.quant_config is not None:
            return self.quant_config.per_act_token_quant
        return False

    @property
    def per_out_ch_quant(self) -> bool:
        if self.quant_config is not None:
            return self.quant_config.per_out_ch_quant
        return False

    # Delegation helpers so callers don't reach into moe_parallel_config.

    @property
    def tp_size(self):
        return self.moe_parallel_config.tp_size

    @property
    def dp_size(self):
        return self.moe_parallel_config.dp_size

    @property
    def ep_size(self):
        return self.moe_parallel_config.ep_size

    @property
    def tp_rank(self):
        return self.moe_parallel_config.tp_rank

    @property
    def dp_rank(self):
        return self.moe_parallel_config.dp_rank

    @property
    def ep_rank(self):
        return self.moe_parallel_config.ep_rank

    @property
    def use_ep(self):
        return self.moe_parallel_config.use_ep

    @property
    def use_pplx_kernels(self):
        return self.moe_parallel_config.use_pplx_kernels

    @property
    def use_deepep_ht_kernels(self):
        return self.moe_parallel_config.use_deepep_ht_kernels

    @property
    def use_deepep_ll_kernels(self):
        return self.moe_parallel_config.use_deepep_ll_kernels

    @property
    def use_flashinfer_cutlass_kernels(self):
        return self.moe_parallel_config.use_flashinfer_cutlass_kernels

    @staticmethod
    def make(
        num_experts: int,
        experts_per_token: int,
        hidden_dim: int,
        num_local_experts: int,
        moe_parallel_config: FusedMoEParallelConfig,
        in_dtype: torch.dtype,
        max_num_tokens: int = envs.VLLM_MOE_DP_CHUNK_SIZE,
        quant_config: Optional[Union[FusedMoEQuantConfig,
                                     QuantizationConfig]] = None,
        has_bias: bool = False,
    ) -> "FusedMoEConfig":
        """Build a FusedMoEConfig.

        When `quant_config` is a vLLM QuantizationConfig, derive an
        equivalent FusedMoEQuantConfig from it; otherwise pass it through
        unchanged (it is already a FusedMoEQuantConfig or None).
        """
        _quant_config: Optional[FusedMoEQuantConfig] = None

        # isinstance(None, ...) is False, so the previous extra
        # "quant_config is not None and" guard was redundant.
        if isinstance(quant_config, QuantizationConfig):
            block_shape = getattr(quant_config, 'weight_block_size', None)
            per_act_token_quant = False
            per_out_ch_quant = False
            quant_dtype: Optional[torch.dtype] = None

            input_quant = get_quant_config_input_quant(quant_config)
            weight_quant = get_quant_config_weight_quant(quant_config)

            if input_quant is not None:
                # (Cleanup: dropped a dead "if input_quant is not None"
                # conditional that was nested inside this same check.)
                per_act_token_quant = (input_quant.strategy ==
                                       QuantizationStrategy.TOKEN)

                if input_quant.num_bits == 8:
                    if input_quant.type == QuantizationType.FLOAT:
                        quant_dtype = torch.float8_e4m3fn
                    elif input_quant.type == QuantizationType.INT:
                        quant_dtype = torch.int8

            # Imported lazily to avoid circular imports at module load.
            from vllm.model_executor.layers.quantization.fp8 import Fp8Config
            if quant_dtype is None and isinstance(quant_config, Fp8Config):
                quant_dtype = torch.float8_e4m3fn

            from vllm.model_executor.layers.quantization.modelopt import (
                ModelOptNvFp4Config)
            if quant_dtype is None and isinstance(quant_config,
                                                  ModelOptNvFp4Config):
                # nvfp4 weights are packed two-per-byte.
                quant_dtype = torch.uint8

            if weight_quant is not None:
                per_out_ch_quant = (
                    weight_quant.strategy == QuantizationStrategy.CHANNEL)

            if quant_dtype is not None:
                _quant_config = FusedMoEQuantConfig(
                    quant_dtype=quant_dtype,
                    per_act_token_quant=per_act_token_quant,
                    per_out_ch_quant=per_out_ch_quant,
                    block_shape=block_shape,
                )
            else:
                # Fall back to an unquantized config; DP all-to-all paths
                # need a recognized scheme, so warn when DP is enabled.
                _quant_config = FusedMoEQuantConfig()
                if moe_parallel_config.dp_size > 1:
                    logger.warning_once("MoE DP setup unable to determine "
                                        "quantization scheme or unsupported "
                                        "quantization type. This model will "
                                        "not run with DP enabled.")
        else:
            _quant_config = quant_config

        return FusedMoEConfig(
            num_experts=num_experts,
            experts_per_token=experts_per_token,
            hidden_dim=hidden_dim,
            num_local_experts=num_local_experts,
            moe_parallel_config=moe_parallel_config,
            in_dtype=in_dtype,
            quant_config=_quant_config,
            max_num_tokens=max_num_tokens,
            has_bias=has_bias,
        )
|
||||
Reference in New Issue
Block a user