[gpt-oss] Add gpt-oss mxfp4 support

2025-08-25 15:31:09 +08:00
parent db7f48eeac
commit 7a35b2f32d
32 changed files with 4835 additions and 1190 deletions

View File

@@ -4,8 +4,12 @@
from contextlib import contextmanager
from typing import Any, Optional
from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig
from vllm.model_executor.layers.fused_moe.layer import (
FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported)
from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEActivationFormat, FusedMoEPermuteExpertsUnpermute,
FusedMoEPrepareAndFinalize)
from vllm.triton_utils import HAS_TRITON
_config: Optional[dict[str, Any]] = None
@@ -26,8 +30,12 @@ def get_config() -> Optional[dict[str, Any]]:
__all__ = [
"FusedMoE",
"FusedMoEConfig",
"FusedMoEMethodBase",
"FusedMoeWeightScaleSupported",
"FusedMoEPermuteExpertsUnpermute",
"FusedMoEActivationFormat",
"FusedMoEPrepareAndFinalize",
"override_config",
"get_config",
]
@@ -36,11 +44,21 @@ if HAS_TRITON:
# import to register the custom ops
import vllm.model_executor.layers.fused_moe.fused_marlin_moe # noqa
import vllm.model_executor.layers.fused_moe.fused_moe # noqa
from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
BatchedDeepGemmExperts)
from vllm.model_executor.layers.fused_moe.batched_triton_or_deep_gemm_moe import ( # noqa: E501
BatchedTritonOrDeepGemmExperts)
from vllm.model_executor.layers.fused_moe.cutlass_moe import (
cutlass_moe_fp4, cutlass_moe_fp8)
CutlassExpertsFp8, cutlass_moe_fp4, cutlass_moe_fp8)
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
DeepGemmExperts)
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
BatchedTritonExperts)
from vllm.model_executor.layers.fused_moe.fused_moe import (
TritonExperts, fused_experts, fused_moe, fused_topk,
get_config_file_name, grouped_topk)
from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
TritonOrDeepGemmExperts)
__all__ += [
"fused_moe",
@@ -50,5 +68,11 @@ if HAS_TRITON:
"grouped_topk",
"cutlass_moe_fp8",
"cutlass_moe_fp4",
"CutlassExpertsFp8",
"TritonExperts",
"BatchedTritonExperts",
"DeepGemmExperts",
"BatchedDeepGemmExperts",
"TritonOrDeepGemmExperts",
"BatchedTritonOrDeepGemmExperts",
]
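An illustrative sketch (not part of the commit) of the expanded public surface: the modular-kernel classes added to __all__ above can now be imported straight from the fused_moe package, assuming a vLLM build that includes this change.

from vllm.model_executor.layers.fused_moe import (
    FusedMoEActivationFormat, FusedMoEConfig, FusedMoEPermuteExpertsUnpermute,
    FusedMoEPrepareAndFinalize)

# Downstream code can reference the activation formats without reaching
# into modular_kernel directly.
assert FusedMoEActivationFormat.Standard != FusedMoEActivationFormat.BatchedExperts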

View File

@@ -0,0 +1,490 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from typing import Optional, Union
import torch
from compressed_tensors.quantization import (QuantizationArgs,
QuantizationStrategy,
QuantizationType)
import vllm.envs as envs
from vllm.config import ParallelConfig
from vllm.distributed import get_dp_group, get_tensor_model_parallel_rank
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.utils import cdiv
# from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
logger = init_logger(__name__)
def _get_quant_config_quantization_args(
quant_config: Optional[QuantizationConfig],
prop_name: str,
) -> Optional[QuantizationArgs]:
if (quant_config is not None and hasattr(quant_config, 'target_scheme_map')
and "Linear" in quant_config.target_scheme_map and
"input_activations" in quant_config.target_scheme_map["Linear"]):
return quant_config.target_scheme_map["Linear"].get(prop_name)
else:
return None
def get_quant_config_input_quant(
quant_config: Optional[QuantizationConfig]
) -> Optional[QuantizationArgs]:
return _get_quant_config_quantization_args(quant_config,
"input_activations")
def get_quant_config_weight_quant(
quant_config: Optional[QuantizationConfig]
) -> Optional[QuantizationArgs]:
return _get_quant_config_quantization_args(quant_config, "weights")
# TODO (bnell): use scalar_type instead of bools?
def get_config_quant_dtype(
use_fp8_w8a8: bool,
use_int8_w8a8: bool,
use_int8_w8a16: bool,
use_int4_w4a16: bool,
use_mxfp4_w4a4: bool,
) -> Union[None, torch.dtype, str]:
if use_fp8_w8a8:
return torch.float8_e4m3fn
elif use_int8_w8a8:
return torch.int8
elif use_mxfp4_w4a4:
return "mxfp4"
return None
@dataclass
class FusedMoEQuantConfig:
# The post quantization activation type.
quant_dtype: Optional[torch.dtype] = None
per_act_token_quant: bool = False
per_out_ch_quant: bool = False
block_shape: Optional[list[int]] = None
# TODO: add col major flag?
# add detailed quant info for input, intermediates, weights, etc?
def __post_init__(self):
assert (not self.per_act_token_quant
or self.block_shape is None), "illegal quantization"
@property
def is_quantized(self) -> bool:
return self.quant_dtype is not None
@property
def is_per_act_token(self) -> bool:
return self.per_act_token_quant
@property
def is_block_quantized(self) -> bool:
return self.block_shape is not None
@property
def is_per_tensor(self) -> bool:
return not self.per_act_token_quant and self.block_shape is None
def scale_shape(
self,
max_tokens: int,
hidden_dim: int,
) -> Optional[tuple[int, int]]:
if self.is_quantized:
if self.is_block_quantized:
assert self.block_shape is not None
_, block_k = self.block_shape
k_tiles = cdiv(hidden_dim, block_k)
return (max_tokens, k_tiles)
elif self.is_per_act_token:
return (max_tokens, 1)
else:
return (1, 1)
else:
return None
def batched_scale_shape(
self,
num_experts: int,
max_tokens: int,
hidden_dim: int,
) -> Optional[tuple[int, int, int]]:
if self.is_quantized:
scale_shape = self.scale_shape(max_tokens, hidden_dim)
assert scale_shape is not None
return (num_experts, *scale_shape)
else:
return None
@staticmethod
def make(
use_fp8_w8a8: bool = False,
use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False,
use_mxfp4_w4a4: bool = False,
per_act_token_quant: bool = False,
per_out_ch_quant: bool = False,
block_shape: Optional[list[int]] = None,
) -> "FusedMoEQuantConfig":
assert sum([
int(flag) for flag in [
use_fp8_w8a8,
use_int8_w8a8,
use_int8_w8a16,
use_int4_w4a16,
]
]) <= 1, "Quantization flags are mutually exclusive."
quant_dtype = get_config_quant_dtype(
use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16,
use_mxfp4_w4a4=use_mxfp4_w4a4,
)
return FusedMoEQuantConfig(
quant_dtype,
per_act_token_quant,
per_out_ch_quant,
block_shape,
)
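As a quick illustration of the flag-to-dtype mapping above (an editorial sketch, not part of the commit): make() maps the mxfp4 flag to the string sentinel "mxfp4" rather than a torch dtype, fp8 maps to torch.float8_e4m3fn, and scale_shape() reflects the chosen granularity.

import torch

cfg_mxfp4 = FusedMoEQuantConfig.make(use_mxfp4_w4a4=True)
assert cfg_mxfp4.quant_dtype == "mxfp4" and cfg_mxfp4.is_quantized

cfg_fp8_block = FusedMoEQuantConfig.make(use_fp8_w8a8=True,
                                         block_shape=[128, 128])
assert cfg_fp8_block.quant_dtype == torch.float8_e4m3fn
assert cfg_fp8_block.is_block_quantized
# Blocked scales: one scale per (token, 128-wide k-tile).
assert cfg_fp8_block.scale_shape(max_tokens=16, hidden_dim=4096) == (16, 32)

# Per-activation-token scales: one scale per token row.
cfg_fp8_token = FusedMoEQuantConfig.make(use_fp8_w8a8=True,
                                         per_act_token_quant=True)
assert cfg_fp8_token.scale_shape(max_tokens=16, hidden_dim=4096) == (16, 1)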
@dataclass
class FusedMoEParallelConfig:
tp_size: int
dp_size: int
ep_size: int
tp_rank: int
dp_rank: int
ep_rank: int
use_ep: bool # whether to use EP or not
@property
def use_all2all_kernels(self):
return self.dp_size > 1 and self.use_ep
@property
def use_pplx_kernels(self):
return (self.use_all2all_kernels
and envs.VLLM_ALL2ALL_BACKEND == "pplx")
@property
def use_deepep_ht_kernels(self):
return (self.use_all2all_kernels
and envs.VLLM_ALL2ALL_BACKEND == "deepep_high_throughput")
@property
def use_deepep_ll_kernels(self):
return (self.use_all2all_kernels
and envs.VLLM_ALL2ALL_BACKEND == "deepep_low_latency")
@property
def use_flashinfer_cutlass_kernels(self):
# return (envs.VLLM_USE_FLASHINFER_MOE_FP4
# and has_flashinfer_cutlass_fused_moe()
# and envs.VLLM_FLASHINFER_MOE_BACKEND == "throughput")
return False
@staticmethod
def make(tp_size_: int, dp_size_: int,
vllm_parallel_config: ParallelConfig) -> "FusedMoEParallelConfig":
"""
Determine MoE parallel configuration. Based on the input `tp_size_`,
`dp_size_` and vLLM's parallel config, determine what
levels of parallelism to use in the fused MoE layer.
Args:
tp_size_ (int): `tp_size` passed into the FusedMoE constructor.
dp_size_ (int): `dp_size` passed into the FusedMoE constructor.
vllm_parallel_config (ParallelConfig): vLLM's parallel config
object which contains the `enable_expert_parallel` flag.
Examples:
When there is no parallelism requested,
i.e. `tp_size_` = `dp_size_` = 1, we simply return the sizes
unaltered and the ranks set to 0.
Expert Parallelism is considered only when either `dp_size_` or
`tp_size_` is non-trivial.
When TP = 2, DP = 1 and EP = False, the configuration on different
devices:
- device 0 : TP = {2, 0} DP = {1, 0} EP = {1, 0} //
legend : {size, rank}
- device 1 : TP = {2, 1} DP = {1, 0} EP = {1, 0}
- Comment : Tensors are sharded across 2 devices.
When TP = 1, DP = 2 and EP = False, the configuration on different
devices:
- device 0 : TP = {2, 0} DP = {2, 0} EP = {1, 0}
- device 1 : TP = {2, 1} DP = {2, 1} EP = {1, 0}
- Comment: There are 2 engine instances and the tensors are sharded
across 2 devices.
When TP = 2, DP = 2 and EP = False, the configuration on different
devices:
- device 0: TP = {4, 0} DP = {2, 0} EP = {1, 0}
- device 1: TP = {4, 1} DP = {2, 0} EP = {1, 0}
- device 2: TP = {4, 2} DP = {2, 1} EP = {1, 0}
- device 3: TP = {4, 3} DP = {2, 1} EP = {1, 0}
- Comment: There are 2 engine instances and the tensors are sharded
across 4 devices.
When, TP = 2, DP = 1 and EP = True, the configuration on different
devices:
- device 0: TP = {1, 0} DP = {1, 0} EP = {2, 0}
- device 1: TP = {1, 0} DP = {1, 0} EP = {2, 1}
- Comment: The experts are split between the 2 devices.
When, TP = 1, DP = 2 and EP = True, the configuration on different
devices:
- device 0: TP = {1, 0} DP = {2, 0} EP = {2, 0}
- device 1: TP = {1, 0} DP = {2, 1} EP = {2, 1}
- Comment: There are 2 engine instances and the experts are split
between the 2 devices.
When TP = 2, DP = 2 and EP = True, the configuration on different
devices:
- device 0: TP = {1, 0} DP = {2, 0} EP = {4, 0}
- device 1: TP = {1, 0} DP = {2, 0} EP = {4, 1}
- device 2: TP = {1, 0} DP = {2, 1} EP = {4, 2}
- device 3: TP = {1, 0} DP = {2, 1} EP = {4, 3}
- Comment: There are 2 engine instances and the experts are split
between the 4 devices.
"""
def flatten_tp_across_dp(dp_rank: int):
tp_rank = 0 if tp_size_ == 1 else get_tensor_model_parallel_rank()
# There are actually dp_size_ * tp_size_ devices. Update tp_size
# and tp_rank so we shard across all devices.
tp_size = dp_size_ * tp_size_
tp_rank = dp_rank * tp_size_ + tp_rank
return tp_size, tp_rank
use_ep = (dp_size_ * tp_size_ > 1
and vllm_parallel_config.enable_expert_parallel)
dp_size = dp_size_
dp_rank = get_dp_group().rank_in_group if dp_size > 1 else 0
tp_size, tp_rank = flatten_tp_across_dp(dp_rank)
if not use_ep:
return FusedMoEParallelConfig(tp_size=tp_size,
tp_rank=tp_rank,
dp_size=dp_size,
dp_rank=dp_rank,
ep_size=1,
ep_rank=0,
use_ep=False)
# DP + EP / TP + EP / DP + TP + EP
assert use_ep
# In EP, each device owns a set of experts fully. There is no tensor
# parallelism. Update tp_size, tp_rank, ep_size and ep_rank to reflect that.
ep_size = tp_size
ep_rank = tp_rank
return FusedMoEParallelConfig(tp_size=1,
tp_rank=0,
dp_size=dp_size,
dp_rank=dp_rank,
ep_size=ep_size,
ep_rank=ep_rank,
use_ep=True)
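The flattening arithmetic inside flatten_tp_across_dp can be checked by hand against the TP = 2, DP = 2, EP = False case from the docstring (a worked sketch, not part of the commit):

# dp_size_ = 2, tp_size_ = 2 -> 4 devices total; TP is flattened across DP.
dp_size_, tp_size_ = 2, 2
for dp_rank in range(dp_size_):
    for local_tp_rank in range(tp_size_):
        tp_size = dp_size_ * tp_size_              # 4
        tp_rank = dp_rank * tp_size_ + local_tp_rank
        print(f"device {tp_rank}: TP = {{{tp_size}, {tp_rank}}} "
              f"DP = {{{dp_size_}, {dp_rank}}}")
# Prints TP = {4, 0}..{4, 3} with DP = {2, 0}/{2, 1}, matching the docstring.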
# Adapted from pplx-kernels tests/all_to_all_utils.py
@dataclass
class FusedMoEConfig:
num_experts: int
experts_per_token: int
hidden_dim: int
num_local_experts: int
moe_parallel_config: FusedMoEParallelConfig
# The activation type.
in_dtype: torch.dtype
quant_config: Optional[FusedMoEQuantConfig] = None
max_num_tokens: int = envs.VLLM_MOE_DP_CHUNK_SIZE
has_bias: bool = False
def __post_init__(self):
if self.dp_size > 1:
logger.debug_once("Using FusedMoEConfig::max_num_tokens=%d",
self.max_num_tokens)
assert self.max_num_tokens > 0
@property
def quant_dtype(self) -> Optional[torch.dtype]:
if self.quant_config is not None:
return self.quant_config.quant_dtype
else:
return None
@property
def block_shape(self) -> Optional[list[int]]:
if self.quant_config is not None:
return self.quant_config.block_shape
else:
return None
@property
def per_act_token_quant(self) -> bool:
if self.quant_config is not None:
return self.quant_config.per_act_token_quant
else:
return False
@property
def per_out_ch_quant(self) -> bool:
if self.quant_config is not None:
return self.quant_config.per_out_ch_quant
else:
return False
@property
def tp_size(self):
return self.moe_parallel_config.tp_size
@property
def dp_size(self):
return self.moe_parallel_config.dp_size
@property
def ep_size(self):
return self.moe_parallel_config.ep_size
@property
def tp_rank(self):
return self.moe_parallel_config.tp_rank
@property
def dp_rank(self):
return self.moe_parallel_config.dp_rank
@property
def ep_rank(self):
return self.moe_parallel_config.ep_rank
@property
def use_ep(self):
return self.moe_parallel_config.use_ep
@property
def use_pplx_kernels(self):
return self.moe_parallel_config.use_pplx_kernels
@property
def use_deepep_ht_kernels(self):
return self.moe_parallel_config.use_deepep_ht_kernels
@property
def use_deepep_ll_kernels(self):
return self.moe_parallel_config.use_deepep_ll_kernels
@property
def use_flashinfer_cutlass_kernels(self):
return self.moe_parallel_config.use_flashinfer_cutlass_kernels
@staticmethod
def make(
num_experts: int,
experts_per_token: int,
hidden_dim: int,
num_local_experts: int,
moe_parallel_config: FusedMoEParallelConfig,
in_dtype: torch.dtype,
max_num_tokens: int = envs.VLLM_MOE_DP_CHUNK_SIZE,
quant_config: Optional[Union[FusedMoEQuantConfig,
QuantizationConfig]] = None,
has_bias: bool = False,
) -> "FusedMoEConfig":
_quant_config: Optional[FusedMoEQuantConfig] = None
if quant_config is not None and isinstance(quant_config,
QuantizationConfig):
if hasattr(quant_config, 'weight_block_size'):
block_shape = quant_config.weight_block_size
else:
block_shape = None
per_act_token_quant = False
per_out_ch_quant = False
quant_dtype: Optional[torch.dtype] = None
input_quant = get_quant_config_input_quant(quant_config)
weight_quant = get_quant_config_weight_quant(quant_config)
if input_quant is not None:
per_act_token_quant = (input_quant.strategy
== QuantizationStrategy.TOKEN
if input_quant is not None else False)
if input_quant.num_bits == 8:
if input_quant.type == QuantizationType.FLOAT:
quant_dtype = torch.float8_e4m3fn
elif input_quant.type == QuantizationType.INT:
quant_dtype = torch.int8
from vllm.model_executor.layers.quantization.fp8 import Fp8Config
if quant_dtype is None and isinstance(quant_config, Fp8Config):
quant_dtype = torch.float8_e4m3fn
from vllm.model_executor.layers.quantization.modelopt import (
ModelOptNvFp4Config)
if quant_dtype is None and isinstance(quant_config,
ModelOptNvFp4Config):
quant_dtype = torch.uint8
if weight_quant is not None:
per_out_ch_quant = (
weight_quant.strategy == QuantizationStrategy.CHANNEL)
if quant_dtype is not None:
_quant_config = FusedMoEQuantConfig(
quant_dtype=quant_dtype,
per_act_token_quant=per_act_token_quant,
per_out_ch_quant=per_out_ch_quant,
block_shape=block_shape,
)
else:
_quant_config = FusedMoEQuantConfig()
if moe_parallel_config.dp_size > 1:
logger.warning_once("MoE DP setup unable to determine "
"quantization scheme or unsupported "
"quantization type. This model will "
"not run with DP enabled.")
else:
_quant_config = quant_config
return FusedMoEConfig(
num_experts=num_experts,
experts_per_token=experts_per_token,
hidden_dim=hidden_dim,
num_local_experts=num_local_experts,
moe_parallel_config=moe_parallel_config,
in_dtype=in_dtype,
quant_config=_quant_config,
max_num_tokens=max_num_tokens,
has_bias=has_bias,
)
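To see how the pieces fit together, here is a hypothetical single-GPU construction with an mxfp4 quant config — a sketch with illustrative sizes (32 experts, 4 active, hidden_dim 2880) that are not taken from the commit:

import torch

parallel = FusedMoEParallelConfig(tp_size=1, tp_rank=0,
                                  dp_size=1, dp_rank=0,
                                  ep_size=1, ep_rank=0,
                                  use_ep=False)
moe_cfg = FusedMoEConfig.make(
    num_experts=32,
    experts_per_token=4,
    hidden_dim=2880,
    num_local_experts=32,
    moe_parallel_config=parallel,
    in_dtype=torch.bfloat16,
    quant_config=FusedMoEQuantConfig.make(use_mxfp4_w4a4=True),
)
assert moe_cfg.quant_dtype == "mxfp4" and not moe_cfg.use_ep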

View File

@@ -1503,8 +1503,8 @@ def fused_experts_impl(
qcurr_hidden_states, a1q_scale = moe_kernel_quantize_input(
A=curr_hidden_states,
A_scale=a1_scale,
qtype=qtype,
per_channel_quant=per_channel_quant,
quant_dtype=qtype,
per_act_token_quant=per_channel_quant,
block_shape=block_shape)
invoke_fused_moe_kernel(qcurr_hidden_states,
@@ -1562,8 +1562,8 @@ def fused_experts_impl(
qintermediate_cache2, a2q_scale = moe_kernel_quantize_input(
A=intermediate_cache2,
A_scale=a2_scale,
qtype=qtype,
per_channel_quant=per_channel_quant,
quant_dtype=qtype,
per_act_token_quant=per_channel_quant,
block_shape=block_shape)
invoke_fused_moe_kernel(qintermediate_cache2,

View File

@@ -0,0 +1,248 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import TYPE_CHECKING, Any, Optional
import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceDelegate)
from vllm.model_executor.layers.fused_moe.utils import extract_required_args
from vllm.utils import has_triton_kernels
logger = init_logger(__name__)
if has_triton_kernels():
try:
import triton_kernels.swiglu
from triton_kernels.matmul_ogs import (FnSpecs, FusedActivation,
matmul_ogs)
from triton_kernels.routing import routing
except ModuleNotFoundError:
logger.error(
"Failed to import Triton kernels. Please make sure your triton "
"version is compatible.")
if TYPE_CHECKING:
from triton_kernels.matmul_ogs import PrecisionConfig
def triton_kernel_moe_forward(
hidden_states: torch.Tensor,
w1, # Tensor or triton_kernels.Tensor
w2, # Tensor or triton_kernels.Tensor
gating_output: torch.Tensor,
topk: int,
renormalize: bool,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
use_fp8_w8a8: bool = False,
per_channel_quant: bool = False,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
w1_bias: Optional[torch.Tensor] = None,
w2_bias: Optional[torch.Tensor] = None,
w1_precision: Optional["PrecisionConfig"] = None,
w2_precision: Optional["PrecisionConfig"] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[list[int]] = None,
) -> torch.Tensor:
routing_data, gather_idx, scatter_idx = routing(gating_output,
topk,
sm_first=not renormalize)
return triton_kernel_fused_experts(
None,
hidden_states,
w1,
w2,
routing_data,
gather_idx,
scatter_idx,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
use_fp8_w8a8=use_fp8_w8a8,
per_channel_quant=per_channel_quant,
global_num_experts=global_num_experts,
expert_map=expert_map,
w1_scale=w1_scale,
w2_scale=w2_scale,
w1_bias=w1_bias,
w2_bias=w2_bias,
w1_precision=w1_precision,
w2_precision=w2_precision,
a1_scale=a1_scale,
a2_scale=a2_scale,
block_shape=block_shape)
# This is a triton implementation of the fused_experts function
def triton_kernel_fused_experts(
output_tensor: torch.Tensor,
hidden_states: torch.Tensor,
w1, # Tensor or triton_kernels.Tensor
w2, # Tensor or triton_kernels.Tensor
routing_data, # RoutingData
gather_indx, # GatherIndx
scatter_indx, # ScatterIndx
activation: str = "silu",
swiglu_alpha: float = 1.702,
swiglu_limit: float = 7.0,
apply_router_weight_on_input: bool = False,
use_fp8_w8a8: bool = False,
per_channel_quant: bool = False,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
w1_bias: Optional[torch.Tensor] = None,
w2_bias: Optional[torch.Tensor] = None,
w1_precision: Optional["PrecisionConfig"] = None,
w2_precision: Optional["PrecisionConfig"] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[list[int]] = None,
) -> torch.Tensor:
# type check, uint8 means mxfp4
assert hidden_states.dtype == torch.bfloat16
assert w1_bias is None or w1_bias.dtype == torch.float32
assert w2_bias is None or w2_bias.dtype == torch.float32
# Shape check, only check non-mxfp4
assert hidden_states.shape[-1] == w1.shape[-2]
assert w2.shape[-1] == w1.shape[1]
E, _, N = w1.shape
if global_num_experts == -1:
global_num_experts = E
act = FusedActivation(
FnSpecs("swiglu", triton_kernels.swiglu.swiglu_fn, ("alpha", "limit")),
(swiglu_alpha, swiglu_limit), 2)
gammas = routing_data.gate_scal if routing_data else None
intermediate_cache1 = matmul_ogs(
hidden_states,
w1,
w1_bias,
routing_data,
gather_indx=gather_indx,
precision_config=w1_precision,
gammas=gammas if apply_router_weight_on_input else None,
fused_activation=act)
intermediate_cache3 = matmul_ogs(
intermediate_cache1,
w2,
w2_bias,
routing_data,
scatter_indx=scatter_indx,
precision_config=w2_precision,
gammas=None if apply_router_weight_on_input else gammas,
y=output_tensor,
)
return intermediate_cache3
class BatchedOAITritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
def __init__(
self,
quant_config,
max_num_tokens: int,
num_dispatchers: int,
w1_precision: "PrecisionConfig",
w2_precision: "PrecisionConfig",
):
super().__init__(quant_config)
self.max_num_tokens = max_num_tokens
self.num_dispatchers = num_dispatchers
self.w1_precision = w1_precision
self.w2_precision = w2_precision
@property
def activation_formats(
self
) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
return (mk.FusedMoEActivationFormat.BatchedExperts,
mk.FusedMoEActivationFormat.BatchedExperts)
def supports_chunking(self) -> bool:
return False
def supports_expert_map(self) -> bool:
return False
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
# Let PrepareAndFinalize::finalize() decide the impl.
return TopKWeightAndReduceDelegate()
def workspace_shapes(
self, a: torch.Tensor, aq: torch.Tensor, M: int, N: int, K: int,
topk: int, global_num_experts: int, local_num_experts: int,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata]
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
# workspaces are allocated inside the kernel
assert a.dim() == 2
num_dp = self.num_dispatchers
num_experts = local_num_experts
max_num_tokens = self.max_num_tokens
workspace2 = (0, 0, 0)
output = (num_experts, max_num_tokens * num_dp, N)
return (output, workspace2, output, a.dtype)
def apply(
self,
output: torch.Tensor,
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: str,
global_num_experts: int,
expert_map: Optional[torch.Tensor],
w1_scale: Optional[torch.Tensor],
w2_scale: Optional[torch.Tensor],
w1_zp: Optional[torch.Tensor],
w2_zp: Optional[torch.Tensor],
a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
apply_router_weight_on_input: bool,
extra_expert_args: Optional[dict[str, Any]],
):
w1_bias, w2_bias = (extract_required_args(extra_expert_args,
["w1_bias", "w2_bias"]))
return triton_kernel_fused_experts(
output,
hidden_states,
w1,
w2,
None,
None,
None,
activation=activation,
apply_router_weight_on_input=False,
use_fp8_w8a8=False,
per_channel_quant=False,
global_num_experts=global_num_experts,
expert_map=expert_map,
w1_scale=w1_scale,
w2_scale=w2_scale,
w1_bias=w1_bias,
w2_bias=w2_bias,
w1_precision=self.w1_precision,
w2_precision=self.w2_precision,
a1_scale=a1q_scale,
a2_scale=a2_scale)

File diff suppressed because it is too large

View File

@@ -1,10 +1,19 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
from typing import Optional
from dataclasses import dataclass
from enum import Enum
from math import prod
from typing import Any, Optional, final
import torch
import vllm.envs as envs
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.utils import ( # yapf: disable
_resize_cache, count_expert_num_tokens)
from vllm.utils import cdiv
#
# This file defines a set of base classes used to make MoE kernels more modular.
# The goal is to be able to utilize different communication mechanisms with
@@ -14,7 +23,7 @@ import torch
#
# [Router] → [Quantize-Dispatch] → [Permute-Experts-Unpermute] → [Combine]
#
# Each component will be independent of the others except for
# Each component will be independent of (but may inform) the others except for
# [Quantize-Dispatch] and [Combine] (see below). The components can then be
# mixed and matched so that DP+EP can be supported easily for multiple
# MoE kernel implementations.
@@ -23,13 +32,19 @@ import torch
# * FusedMoEPrepareAndFinalize - an abstract base class for preparation of MoE
# inputs (e.g. quantization, distribution) and finalization of Moe outputs.
# The prepare method must take care of any needed quantization and the
# finalize method must apply weights and do the final reduction of the output.
# finalize method, informed by the FusedMoEPermuteExpertsUnpermute method,
# may apply weights and/or do the final reduction of the output.
# * FusedMoEPermuteExpertsUnpermute - an abstract base class for the main fused
# MoE operation. One important feature to note is that this class does not
# apply topk weights or reduce the final output.
# MoE operation, i.e. matmul + act_mul + optionally quant + matmul.
# Some FusedMoEPermuteExpertsUnpermute implementations may choose to do
# the weight application and/or reduction. The class communicates this
# to [Finalize] via a TopKWeightAndReduce object.
# * FusedMoEModularKernel - an interface class that combines a
# FusedMoEPrepareAndFinalize and a FusedMoEPermuteExpertsUnpermute to
# provide the standard fused MoE kernel interface.
# * TopKWeightAndReduce - A TopKWeightAndReduce implementation chosen
# by the FusedMoEPermuteExpertsUnpermute implementation that is passed
# on to [Finalize].
#
# [Quantize-Prepare] and [Finalize] functionality are bundled into a single
# class `FusedMoEPrepareAndFinalize` since they could use collective
@@ -77,6 +92,56 @@ def _moe_problem_size(
return E, M, N, K, topk
class FusedMoEActivationFormat(Enum):
"""
The standard activation format (num_tokens, hidden dim).
"""
Standard = "standard",
"""
The batched experts format (num experts, max tokens per expert, hidden dim)
"""
BatchedExperts = "batched_experts",
@dataclass
class ExpertTokensMetadata:
"""
Metadata regarding expert-token routing.
"""
expert_num_tokens: torch.Tensor
expert_num_tokens_cpu: Optional[torch.Tensor]
@staticmethod
def make_from_list(expert_num_tokens_list: list[int],
device: str) -> "ExpertTokensMetadata":
expert_num_tokens_cpu = torch.tensor(expert_num_tokens_list,
device="cpu",
dtype=torch.int32)
return ExpertTokensMetadata(
expert_num_tokens=expert_num_tokens_cpu.to(device,
non_blocking=True),
expert_num_tokens_cpu=expert_num_tokens_cpu)
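For example (a small sketch, not from the commit, assuming a CUDA device is available): make_from_list keeps a CPU copy next to the device tensor so implementations that need host-side counts can avoid a later sync.

meta = ExpertTokensMetadata.make_from_list([3, 0, 5, 2], device="cuda")
assert meta.expert_num_tokens.device.type == "cuda"
assert meta.expert_num_tokens_cpu is not None  # host-side copy retained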
class TopKWeightAndReduce(ABC):
"""
An abstract base class for weight application and reduction implementations.
"""
@abstractmethod
def apply(self, output: Optional[torch.Tensor],
fused_expert_output: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
apply_router_weight_on_input: bool) -> torch.Tensor:
"""
Apply topk_weights to the fused_experts_outputs and/or reduce.
If an output tensor is not passed, it will be created in the
function.
"""
raise NotImplementedError
# TODO: pass FusedMoEParallelConfig in as ctor parameter?
class FusedMoEPrepareAndFinalize(ABC):
"""
An abstract base class for the [Quantize-Prepare] and [Finalize] steps
@@ -85,17 +150,15 @@ class FusedMoEPrepareAndFinalize(ABC):
@abstractmethod
def prepare(
self,
a1: torch.Tensor,
a1_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
num_experts: int,
expert_map: Optional[torch.Tensor],
apply_router_weight_on_input: bool,
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor],
Optional[torch.Tensor], Optional[torch.Tensor]]:
self, a1: torch.Tensor, a1_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor], topk_weights: torch.Tensor,
topk_ids: torch.Tensor, num_experts: int,
expert_map: Optional[torch.Tensor], apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig,
extra_prepare_args: Optional[dict[str, Any]]
) -> tuple[torch.Tensor, Optional[torch.Tensor],
Optional[ExpertTokensMetadata], Optional[torch.Tensor],
Optional[torch.Tensor]]:
"""
Perform any quantization (and/or) dispatching needed
for this kernel.
@@ -114,22 +177,20 @@ class FusedMoEPrepareAndFinalize(ABC):
Returns a tuple of:
- quantized + dispatched a.
- quantized + dispatched a1_scales.
- Optional tensor as big as number of local experts that contains the
number of tokens assigned to each local expert.
- Optional ExpertTokensMetadata containing gpu/cpu tensors
as big as the number of local experts with the information about the
number of tokens assigned to each local expert.
- Optional dispatched expert topk IDs
- Optional dispatched expert topk weight
- Optional dispatched expert topk weight
"""
raise NotImplementedError
@abstractmethod
def finalize(
self,
output: torch.Tensor,
fused_expert_output: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
apply_router_weight_on_input: bool,
) -> None:
def finalize(self, output: torch.Tensor, fused_expert_output: torch.Tensor,
topk_weights: torch.Tensor, topk_ids: torch.Tensor,
apply_router_weight_on_input: bool,
weight_and_reduce_impl: TopKWeightAndReduce,
extra_finalize_args: Optional[dict[str, Any]]) -> None:
"""
Perform any combine plus apply weights and perform a reduction on the
fused experts output.
@@ -140,6 +201,17 @@ class FusedMoEPrepareAndFinalize(ABC):
- topk_ids: The topk_ids.
- apply_router_weight_on_input: When False, apply the weights to
fused_expert_output.
- weight_and_reduce_impl: An optional TopKWeightAndReduce
implementation.
"""
raise NotImplementedError
@property
@abstractmethod
def activation_format(self) -> FusedMoEActivationFormat:
"""
A property indicating the output format of the activations for the
'prepare' method.
"""
raise NotImplementedError
@@ -159,11 +231,15 @@ class FusedMoEPrepareAndFinalize(ABC):
Some PrepareFinalize All2All implementations are batched. Meaning,
they can process only a set number of tokens at a time. This
function returns the batch size, i.e. the maximum number of tokens
the implementation can process at a time.
Return None if there are no such restrictions.
"""
raise NotImplementedError
@abstractmethod
def num_dispatchers(self) -> int:
raise NotImplementedError
class FusedMoEPermuteExpertsUnpermute(ABC):
"""
@@ -171,6 +247,57 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
above.
"""
def __init__(
self,
quant_config: Optional[FusedMoEQuantConfig],
):
if quant_config is not None:
self.quant_config = quant_config
else:
self.quant_config = FusedMoEQuantConfig()
@property
@abstractmethod
def activation_formats(
self) -> tuple[FusedMoEActivationFormat, FusedMoEActivationFormat]:
"""
A property which is a tuple of the input and output activation formats
for the 'apply' method.
"""
raise NotImplementedError
@property
def quant_dtype(self) -> Optional[torch.dtype]:
return self.quant_config.quant_dtype
@property
def block_shape(self) -> Optional[list[int]]:
return self.quant_config.block_shape
@property
def per_act_token_quant(self) -> bool:
return self.quant_config.per_act_token_quant
@property
def per_out_ch_quant(self) -> bool:
return self.quant_config.per_out_ch_quant
# TODO (bnell): make this return a CHUNK_SIZE or None instead?
@abstractmethod
def supports_chunking(self) -> bool:
"""
A flag indicating whether or not this class supports activation
chunking.
"""
raise NotImplementedError
@abstractmethod
def supports_expert_map(self) -> bool:
"""
A flag indicating whether or not this class supports expert maps
"""
raise NotImplementedError
@abstractmethod
def workspace_shapes(
self,
@@ -180,20 +307,25 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
N: int,
K: int,
topk: int,
num_experts: int,
) -> tuple[int, int, torch.dtype]:
global_num_experts: int,
local_num_experts: int,
expert_tokens_meta: Optional[ExpertTokensMetadata],
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
"""
Compute the number of elements for the temporary outputs of the two
gemms and activation in the fused expert function. Since the
gemms are independent, the workspace for the first gemm can be shared
with the workspace for the last gemm.
Compute the shapes for the temporary and final outputs of the two gemms
and activation in the fused expert function. Since the gemms are
independent, the workspace for the first gemm can be shared with the
workspace for the last gemm.
Returns a tuple of:
- Number of workspace13 elements: must be large enough to hold the
- workspace13 shape tuple: must be large enough to hold the
result of either expert gemm.
- Number of workspace2 elements: must be large enough to hold the
- workspace2 shape tuple: must be large enough to hold the
result of the activation function.
- output shape tuple: must be exact size of the final gemm output.
- Workspace type: The dtype to use for the workspace tensors.
- Note: in order for activation chunking to work, the first dimension
of each tuple must be the number of tokens.
"""
raise NotImplementedError
@@ -207,12 +339,21 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
else:
raise ValueError(f"Unsupported FusedMoe activation: {activation}")
def enable_chunking(self):
return envs.VLLM_ENABLE_FUSED_MOE_ACTIVATION_CHUNKING and \
self.supports_chunking()
def finalize_weight_and_reduce_impl(self) -> TopKWeightAndReduce:
raise NotImplementedError
@abstractmethod
def apply(
self,
output: torch.Tensor,
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: str,
global_num_experts: int,
@@ -225,17 +366,22 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
a2_scale: Optional[torch.Tensor],
workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_num_tokens: Optional[torch.Tensor],
) -> torch.Tensor:
expert_tokens_meta: Optional[ExpertTokensMetadata],
apply_router_weight_on_input: bool,
extra_expert_args: Optional[dict[str, Any]],
):
"""
This function computes the intermediate result of a Mixture of Experts
(MoE) layer using two sets of weights, w1 and w2.
Parameters:
- output: (torch.Tensor): The unweighted, unreduced output tensor.
- hidden_states: (torch.Tensor): The (quantized) input tensor to the MoE
layer.
- w1 (torch.Tensor): The first set of expert weights.
- w2 (torch.Tensor): The second set of expert weights.
- topk_weights: A map of row to expert weights. Some implementations
choose to do weight application.
- topk_ids (torch.Tensor): A map of row to expert id.
- activation (str): The activation function to apply after the first
MoE layer.
@@ -257,15 +403,28 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
must be large enough to hold output of either MoE gemm.
- workspace2 (torch.Tensor): A scratch tensor used for the activation
function.
- expert_num_tokens: An optional tensor containing the number of tokens
assigned to each expert when using batched experts format input.
Returns:
- torch.Tensor: The unweighted, unreduced output tensor
- expert_tokens_meta (Optional[ExpertTokensMetadata]) - An optional
ExpertTokensMetadata object containing gpu/cpu tensors
as big as the number of local experts with the information about the
number of tokens assigned to each local expert.
- apply_router_weight_on_input: True if router weights are already
applied on the input. This is relevant if the implementation
chooses to do weight application.
"""
raise NotImplementedError
def _chunk_scales(scales: Optional[torch.Tensor], start: int,
end: int) -> Optional[torch.Tensor]:
if scales is not None:
if scales.numel() == 1:
return scales
else:
return scales[start:end]
return None
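A quick illustration of the slicing rule (editorial sketch): per-tensor scales (a single element) are shared across chunks, while per-token scales are sliced down to the chunk's rows.

import torch

per_tensor = torch.tensor([0.5])
per_token = torch.rand(64, 1)
assert _chunk_scales(per_tensor, 0, 16) is per_tensor      # shared as-is
assert _chunk_scales(per_token, 16, 32).shape == (16, 1)   # rows 16..31
assert _chunk_scales(None, 0, 16) is None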
@final
class FusedMoEModularKernel(torch.nn.Module):
"""
This class combines a FusedMoEPrepareAndFinalize instance and
@@ -287,46 +446,56 @@ class FusedMoEModularKernel(torch.nn.Module):
super().__init__()
self.prepare_finalize = prepare_finalize
self.fused_experts = fused_experts
assert prepare_finalize.activation_format == \
fused_experts.activation_formats[0], (
f"{prepare_finalize.__class__.__name__}."
f"{prepare_finalize.activation_format} == "
f"{fused_experts.__class__.__name__}."
f"{fused_experts.activation_formats[0]}")
def _do_fused_experts(
self,
a1: torch.Tensor, # input to forward fn
a1q: torch.Tensor, # output of prepare fn
w1: torch.Tensor,
w2: torch.Tensor,
topk_ids: torch.Tensor,
expert_num_tokens: torch.Tensor,
activation: str,
global_num_experts: int,
self, fused_out: Optional[torch.Tensor], a1: torch.Tensor,
a1q: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor,
topk_weights: torch.Tensor, topk_ids: torch.Tensor,
activation: str, global_num_experts: int, local_num_experts: int,
expert_map: Optional[torch.Tensor],
w1_scale: Optional[torch.Tensor],
w2_scale: Optional[torch.Tensor],
w1_zp: Optional[torch.Tensor],
w2_zp: Optional[torch.Tensor],
w1_scale: Optional[torch.Tensor], w2_scale: Optional[torch.Tensor],
w1_zp: Optional[torch.Tensor], w2_zp: Optional[torch.Tensor],
a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor]) -> torch.Tensor:
a2_scale: Optional[torch.Tensor],
expert_tokens_meta: Optional[ExpertTokensMetadata],
apply_router_weight_on_input: bool,
extra_expert_args: Optional[dict[str, Any]]) -> torch.Tensor:
_, M, N, K, top_k = _moe_problem_size(a1q, w1, w2, topk_ids)
# Use a1 here to decipher the correct workspace datatype
workspace13_shape, workspace2_shape, workspace_dtype = (
self.fused_experts.workspace_shapes(a1, a1q, M, N, K, top_k,
global_num_experts))
(workspace13_shape, workspace2_shape, fused_out_shape,
workspace_dtype) = self.fused_experts.workspace_shapes(
a1, a1q, M, N, K, top_k, global_num_experts, local_num_experts,
expert_tokens_meta)
# We can reuse the memory between cache1 and cache3 because by the time
# we need cache3, we're done with cache1
workspace13 = torch.zeros(workspace13_shape,
# We can reuse the memory between cache1 and cache3 because by the
# time we need cache3, we're done with cache1.
workspace13 = torch.empty(prod(workspace13_shape),
device=a1.device,
dtype=workspace_dtype)
workspace2 = torch.zeros(workspace2_shape,
workspace2 = torch.empty(prod(workspace2_shape),
device=a1.device,
dtype=workspace_dtype)
fused_out = self.fused_experts.apply(
assert fused_out is None or fused_out.shape == fused_out_shape, (
f"fused_out {fused_out.shape} but expected {fused_out_shape}")
if fused_out is None:
# reuse workspace13 for the output
fused_out = _resize_cache(workspace13, fused_out_shape)
self.fused_experts.apply(
fused_out,
a1q,
w1,
w2,
topk_ids,
topk_weights=topk_weights,
topk_ids=topk_ids,
activation=activation,
global_num_experts=global_num_experts,
expert_map=expert_map,
@@ -338,8 +507,162 @@ class FusedMoEModularKernel(torch.nn.Module):
a2_scale=a2_scale,
workspace13=workspace13,
workspace2=workspace2,
expert_num_tokens=expert_num_tokens,
)
expert_tokens_meta=expert_tokens_meta,
apply_router_weight_on_input=apply_router_weight_on_input,
extra_expert_args=extra_expert_args)
return fused_out
def _maybe_chunk_fused_experts(
self,
a1: torch.Tensor,
a1q: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: str,
global_num_experts: int,
local_num_experts: int,
expert_map: Optional[torch.Tensor],
w1_scale: Optional[torch.Tensor],
w2_scale: Optional[torch.Tensor],
w1_zp: Optional[torch.Tensor],
w2_zp: Optional[torch.Tensor],
a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
expert_tokens_meta: Optional[ExpertTokensMetadata],
apply_router_weight_on_input: bool,
extra_expert_args: Optional[dict[str, Any]],
) -> torch.Tensor:
_, M, N, K, top_k = _moe_problem_size(a1q, w1, w2, topk_ids)
CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
num_chunks = cdiv(M, CHUNK_SIZE)
if not self.fused_experts.supports_chunking() or num_chunks == 1:
return self._do_fused_experts(
fused_out=None,
a1=a1,
a1q=a1q,
w1=w1,
w2=w2,
topk_weights=topk_weights,
topk_ids=topk_ids,
activation=activation,
global_num_experts=global_num_experts,
local_num_experts=local_num_experts,
expert_map=expert_map,
w1_scale=w1_scale,
w2_scale=w2_scale,
w1_zp=w1_zp,
w2_zp=w2_zp,
a1q_scale=a1q_scale,
a2_scale=a2_scale,
expert_tokens_meta=expert_tokens_meta,
apply_router_weight_on_input=apply_router_weight_on_input,
extra_expert_args=extra_expert_args)
# Chunking required case
assert num_chunks > 1
# Construct the entire output that can then be processed in chunks.
(_, _, fused_out_shape, _) = self.fused_experts.workspace_shapes(
a1, a1q, M, N, K, top_k, global_num_experts, local_num_experts,
expert_tokens_meta)
fused_out = torch.empty(fused_out_shape,
device=a1q.device,
dtype=a1.dtype)
def slice_input_tensors(
chunk_idx: int
) -> tuple[torch.Tensor, Optional[torch.Tensor],
Optional[torch.Tensor], torch.Tensor, torch.Tensor]:
s = chunk_idx * CHUNK_SIZE
e = min(s + CHUNK_SIZE, M)
return (a1q[s:e], _chunk_scales(a1q_scale, s, e),
_chunk_scales(a2_scale, s,
e), topk_ids[s:e], topk_weights[s:e])
def slice_output_tensor(chunk_idx: int) -> torch.Tensor:
assert fused_out.size(0) % M == 0, (
f"fused_out shape {fused_out.shape} vs M {M}")
factor = fused_out.size(0) // M
out_chunk_size = CHUNK_SIZE * factor
s = chunk_idx * out_chunk_size
e = min(s + out_chunk_size, fused_out.size(0))
return fused_out[s:e]
def slice_expert_tokens_metadata(
full_expert_tokens_meta: ExpertTokensMetadata,
chunk_topk_ids: torch.Tensor, local_num_experts: int,
expert_map: Optional[torch.Tensor]) -> ExpertTokensMetadata:
# The existing expert_num_tokens is for the entire a1q
# input. Chunking forces recomputation of the number
# of tokens assigned to each expert.
c_expert_num_tokens = count_expert_num_tokens(
chunk_topk_ids, local_num_experts, expert_map)
c_expert_num_tokens_cpu = None
need_expert_num_tokens_cpu = (
full_expert_tokens_meta.expert_num_tokens_cpu is not None)
if need_expert_num_tokens_cpu:
# This is blocking as some implementations need the count
# on the CPU to determine appropriate input/out fused-moe
# buffers
c_expert_num_tokens_cpu = c_expert_num_tokens.to(
"cpu", non_blocking=False)
return ExpertTokensMetadata(
expert_num_tokens=c_expert_num_tokens,
expert_num_tokens_cpu=c_expert_num_tokens_cpu)
m = None
if extra_expert_args is not None and 'm' in extra_expert_args:
m = extra_expert_args.get('m')
if extra_expert_args is not None:
chunked_extra_expert_args = extra_expert_args
else:
chunked_extra_expert_args = {}
for chunk_idx in range(num_chunks):
c_a1q, c_a1q_scale, c_a2_scale, c_topk_ids, c_topk_weights = (
slice_input_tensors(chunk_idx))
c_expert_tokens_meta = None
if expert_tokens_meta is not None:
c_expert_tokens_meta = slice_expert_tokens_metadata(
expert_tokens_meta, c_topk_ids, local_num_experts,
expert_map)
s = chunk_idx * CHUNK_SIZE
e = min(s + CHUNK_SIZE, M)
if m is not None:
chunked_extra_expert_args['m'] = e - s
self._do_fused_experts(
fused_out=slice_output_tensor(chunk_idx),
a1=a1,
a1q=c_a1q,
w1=w1,
w2=w2,
topk_weights=c_topk_weights,
topk_ids=c_topk_ids,
activation=activation,
global_num_experts=global_num_experts,
local_num_experts=local_num_experts,
expert_map=expert_map,
w1_scale=w1_scale,
w2_scale=w2_scale,
w1_zp=w1_zp,
w2_zp=w2_zp,
a1q_scale=c_a1q_scale,
a2_scale=c_a2_scale,
expert_tokens_meta=c_expert_tokens_meta,
apply_router_weight_on_input=apply_router_weight_on_input,
extra_expert_args=chunked_extra_expert_args)
return fused_out
@@ -361,6 +684,9 @@ class FusedMoEModularKernel(torch.nn.Module):
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
extra_expert_args: Optional[dict] = None,
extra_prepare_args: Optional[dict] = None,
extra_finalize_args: Optional[dict] = None,
) -> torch.Tensor:
"""
This function computes a Mixture of Experts (MoE) layer using two sets
@@ -393,6 +719,12 @@ class FusedMoEModularKernel(torch.nn.Module):
- apply_router_weight_on_input (bool): When true, the topk weights are
applied directly on the inputs. This is only applicable when topk is
1.
- extra_expert_args (Optional[dict]): Extra keyword arguments to pass to
fused_experts.apply.
- extra_prepare_args (Optional[dict]): Extra keyword arguments to pass
to prepare.
- extra_finalize_args (Optional[dict]): Extra keyword arguments to pass
to finalize.
Returns:
- torch.Tensor: The output tensor after applying the MoE layer.
@@ -401,19 +733,31 @@ class FusedMoEModularKernel(torch.nn.Module):
a1 = hidden_states
output = a1 if inplace else torch.zeros_like(a1)
local_num_experts = w1.size(0)
if global_num_experts == -1:
global_num_experts = w1.size(0)
global_num_experts = local_num_experts
(a1q, a1q_scale, expert_num_tokens, _expert_topk_ids,
(a1q, a1q_scale, expert_tokens_meta, _expert_topk_ids,
_expert_topk_weights) = self.prepare_finalize.prepare(
a1, a1_scale, a2_scale, topk_weights, topk_ids,
global_num_experts, expert_map, apply_router_weight_on_input)
a1,
a1_scale,
a2_scale,
topk_weights,
topk_ids,
global_num_experts,
expert_map,
apply_router_weight_on_input,
self.fused_experts.quant_config,
extra_prepare_args,
)
# Maybe prepare gathered topk_ids and topk_weights from other EP ranks.
topk_ids = topk_ids if _expert_topk_ids is None else _expert_topk_ids
topk_weights = (topk_weights if _expert_topk_weights is None else
_expert_topk_weights)
fused_out = None
if a1q.numel() == 0:
# This happens when none of the tokens from the all2all reach this
# EP rank. Also, note that this is only relevant for CUDAGraph
@@ -423,24 +767,31 @@ class FusedMoEModularKernel(torch.nn.Module):
# and can never run into the tensor.numel() == 0 case.
fused_out = torch.empty_like(a1q).to(dtype=a1.dtype)
else:
fused_out = self._do_fused_experts(
fused_out = self._maybe_chunk_fused_experts(
a1=a1,
a1q=a1q,
w1=w1,
w2=w2,
topk_weights=topk_weights,
topk_ids=topk_ids,
expert_num_tokens=expert_num_tokens,
activation=activation,
global_num_experts=global_num_experts,
local_num_experts=local_num_experts,
expert_map=expert_map,
w1_scale=w1_scale,
w2_scale=w2_scale,
w1_zp=w1_zp,
w2_zp=w2_zp,
a1q_scale=a1q_scale,
a2_scale=a2_scale)
a2_scale=a2_scale,
expert_tokens_meta=expert_tokens_meta,
apply_router_weight_on_input=apply_router_weight_on_input,
extra_expert_args=extra_expert_args)
self.prepare_finalize.finalize(output, fused_out, topk_weights,
topk_ids, apply_router_weight_on_input)
self.prepare_finalize.finalize(
output, fused_out, topk_weights, topk_ids,
apply_router_weight_on_input,
self.fused_experts.finalize_weight_and_reduce_impl(),
extra_finalize_args)
return output

View File

@@ -0,0 +1,146 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional
import torch
import vllm._custom_ops as ops
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
class TopKWeightAndReduceDelegate(mk.TopKWeightAndReduce):
"""
Useful in the case when some FusedMoEPermuteExpertsUnpermute
implementation does not perform weight application and reduction
but cannot address the needs of all the compatible PrepareAndFinalize
implementations.
For example, BatchedTritonExperts is compatible with both
PplxPrepareAndFinalize and BatchedPrepareAndFinalize. PplxPrepareAndFinalize
does the weight-application + reduction as part of the pplx combine kernel.
But the BatchedPrepareAndFinalize needs an implementation. To facilitate
this case, the BatchedTritonExperts could use TopKWeightAndReduceDelegate
so the PrepareAndFinalize implementations could choose how to
weight + reduce.
"""
def __eq__(self, other):
return isinstance(other, TopKWeightAndReduceDelegate)
def apply(self, output: Optional[torch.Tensor],
fused_expert_output: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
apply_router_weight_on_input: bool) -> torch.Tensor:
raise RuntimeError("The caller is expected to choose an appropriate "
"TopKWeightAndReduce implementation.")
class TopKWeightAndReduceNoOP(mk.TopKWeightAndReduce):
"""
The fused_experts outputs have already been weight applied and reduced.
This implementation is a no-op.
"""
def __eq__(self, other):
return isinstance(other, TopKWeightAndReduceNoOP)
def apply(self, output: Optional[torch.Tensor],
fused_expert_output: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
apply_router_weight_on_input: bool) -> torch.Tensor:
# Weight application and reduction operations are already done.
if output is None:
return fused_expert_output
# MoEPrepareAndFinalizeNoEP needs the output to be in the `output`
# tensor.
assert output.size() == fused_expert_output.size(), (
"output shape is expected to match the fused_expert_output shape. "
f"But got output={output.size()}, "
f"fused_expert_output={fused_expert_output.size()}")
output.copy_(fused_expert_output, non_blocking=True)
return output
class TopKWeightAndReduceContiguous(mk.TopKWeightAndReduce):
"""
TopKWeightAndReduce implementation for a fused_experts output
of shape (m, topk, K)
"""
def __eq__(self, other):
return isinstance(other, TopKWeightAndReduceContiguous)
def apply(self, output: Optional[torch.Tensor],
fused_expert_output: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
apply_router_weight_on_input: bool) -> torch.Tensor:
m, num_topk = topk_ids.size()
k = fused_expert_output.size(-1)
if fused_expert_output.ndim == 2:
fused_expert_output = fused_expert_output.view(m, num_topk, k)
assert fused_expert_output.size() == (m, num_topk, k), (
f"Expected fused_expert_output size {(m, num_topk, k)}. But got "
f"{fused_expert_output.size()}")
if not apply_router_weight_on_input:
fused_expert_output.mul_(topk_weights.view(m, -1, 1))
if output is None:
output = torch.empty((m, k),
device=fused_expert_output.device,
dtype=fused_expert_output.dtype)
assert output.size() == (m, k), (
f"Expected output size {(m, k)}. But got {output.size()}")
ops.moe_sum(fused_expert_output, output)
return output
class TopKWeightAndReduceNaiveBatched(mk.TopKWeightAndReduce):
"""
TopKWeightAndReduce implementation for a fused_experts output
of shape (num_experts, batch_size, K)
"""
def __init__(self, rank: int):
self.rank = rank
def __eq__(self, other):
return (isinstance(other, TopKWeightAndReduceNaiveBatched)
and (other.rank == self.rank))
def apply(self, output: Optional[torch.Tensor],
fused_expert_output: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
apply_router_weight_on_input: bool) -> torch.Tensor:
assert fused_expert_output.ndim == 3
num_tokens = topk_ids.size(0)
num_local_experts = fused_expert_output.size(0)
K = fused_expert_output.size(-1)
if output is None:
output = torch.zeros((num_tokens, K),
device=fused_expert_output.device,
dtype=fused_expert_output.dtype)
else:
output.fill_(0)
assert output.size() == (num_tokens, K), (
f"Expected output size {(num_tokens, K)}, but got {output.size()}")
first_expert = num_local_experts * self.rank
last_expert = first_expert + num_local_experts
for expert_id in range(first_expert, last_expert):
matching_tokens = topk_ids == expert_id
topks = torch.any(matching_tokens, dim=1).flatten()
rows = torch.count_nonzero(topks)
rhs = fused_expert_output[expert_id - first_expert, :rows, :]
if not apply_router_weight_on_input:
rhs.mul_(topk_weights[matching_tokens].view(rhs.size(0), 1))
output[topks] = output[topks] + rhs
return output
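A usage sketch for the contiguous case (not part of the commit; it assumes vLLM's custom ops are built, since apply() calls ops.moe_sum, and uses illustrative shapes): a fused expert output of shape (m, topk, K) is weighted by topk_weights and reduced to (m, K).

import torch

m, topk, K = 8, 2, 64
fused = torch.randn(m, topk, K, device="cuda", dtype=torch.bfloat16)
weights = torch.softmax(
    torch.randn(m, topk, device="cuda", dtype=torch.bfloat16), dim=-1)
ids = torch.randint(0, 16, (m, topk), device="cuda", dtype=torch.int32)

out = TopKWeightAndReduceContiguous().apply(
    output=None,
    fused_expert_output=fused,
    topk_weights=weights,
    topk_ids=ids,
    apply_router_weight_on_input=False)
assert out.shape == (m, K)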

View File

@@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from math import prod
from typing import Optional
from typing import Any, Optional, Union
import torch
@@ -10,7 +10,83 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import (
per_token_group_quant_fp8)
from vllm.model_executor.layers.quantization.utils.int8_utils import (
per_token_group_quant_int8, per_token_quant_int8)
from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
quant_dequant_mxfp4)
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
from vllm.utils import cdiv
# from vllm.utils.flashinfer import fp4_quantize
@triton.jit
def _count_expert_num_tokens(topk_ids_ptr, expert_num_tokens_ptr, num_experts,
topk_numel, expert_map,
HAS_EXPERT_MAP: tl.constexpr,
BLOCK_SIZE: tl.constexpr):
curr_expert = tl.program_id(0)
offsets = tl.arange(0, BLOCK_SIZE)
topk_ids_ptrs = topk_ids_ptr + offsets
acc = tl.zeros((BLOCK_SIZE, ), dtype=tl.int32)
for x in range(tl.cdiv(topk_numel, BLOCK_SIZE)):
mask = offsets < (topk_numel - x * BLOCK_SIZE)
expert_ids = tl.load(topk_ids_ptrs, mask=mask, other=-1)
if HAS_EXPERT_MAP:
expert_map_ptrs = expert_map + expert_ids
expert_map_mask = expert_ids >= 0
expert_ids = tl.load(expert_map_ptrs,
mask=expert_map_mask,
other=-1)
has_curr_expert = tl.where(expert_ids == curr_expert, 1, 0)
acc = acc + has_curr_expert
topk_ids_ptrs += BLOCK_SIZE
if curr_expert < num_experts:
tl.store(expert_num_tokens_ptr + curr_expert, tl.sum(acc))
def count_expert_num_tokens(
topk_ids: torch.Tensor, num_local_experts: int,
expert_map: Optional[torch.Tensor]) -> torch.Tensor:
"""
Count the number of tokens assigned to each expert.
Parameters:
- topk_ids (torch.Tensor): Tensor mapping each token to its
list of experts.
- num_local_experts (int): Number of experts in this rank.
- expert_map (Optional[torch.Tensor]): A tensor mapping expert indices
from the global expert space to the local expert space of the expert
parallel shard.
Returns:
A tensor of size num_local_experts, where tensor[i] holds the number
of tokens assigned to the ith expert.
"""
assert topk_ids.dtype.is_signed, (
"The kernel uses -1 to represent invalid topk_ids")
expert_num_tokens = torch.empty((num_local_experts),
device=topk_ids.device,
dtype=torch.int32)
grid = num_local_experts
BLOCK_SIZE = min(topk_ids.numel(), 1024)
BLOCK_SIZE = triton.next_power_of_2(BLOCK_SIZE)
_count_expert_num_tokens[(grid, )](
topk_ids,
expert_num_tokens,
num_local_experts,
topk_ids.numel(),
expert_map,
HAS_EXPERT_MAP=expert_map is not None,
BLOCK_SIZE=BLOCK_SIZE,
)
return expert_num_tokens
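The semantics can be sanity-checked against a plain torch reference (a sketch, not from the commit; assumes a CUDA device, Triton available, and no expert_map, i.e. every expert is local):

import torch

topk_ids = torch.randint(0, 4, (128, 2), device="cuda", dtype=torch.int64)
counts = count_expert_num_tokens(topk_ids, num_local_experts=4,
                                 expert_map=None)
# Reference: a bincount over the flattened topk ids gives the same totals.
ref = torch.bincount(topk_ids.flatten(), minlength=4).to(torch.int32)
assert torch.equal(counts, ref)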
def _resize_cache(x: torch.Tensor, v: tuple[int, ...]) -> torch.Tensor:
@@ -23,6 +99,16 @@ def _resize_cache(x: torch.Tensor, v: tuple[int, ...]) -> torch.Tensor:
return x.flatten()[:prod(v)].view(*v)
# def _fp4_quantize(
# A: torch.Tensor,
# A_scale: Optional[torch.Tensor],
# is_sf_swizzled_layout: bool,
# ) -> tuple[torch.Tensor, torch.Tensor]:
# return fp4_quantize(A,
# A_scale,
# is_sf_swizzled_layout=is_sf_swizzled_layout)
def _fp8_quantize(
A: torch.Tensor,
A_scale: Optional[torch.Tensor],
@@ -34,9 +120,12 @@ def _fp8_quantize(
is provided, the output will be blocked.
"""
if block_shape is None:
# TODO(luka): use QuantFP8 custom op
# https://github.com/vllm-project/vllm/issues/20711
A, A_scale = ops.scaled_fp8_quant(
A, A_scale, use_per_token_if_dynamic=per_act_token)
else:
assert not per_act_token
assert len(block_shape) == 2
_, block_k = block_shape[0], block_shape[1]
A, A_scale = per_token_group_quant_fp8(A, block_k)
@@ -62,9 +151,9 @@ def _int8_quantize(
if block_shape is None:
assert per_act_token, \
"int8 quantization only supports block or channel-wise"
# A, A_scale = per_token_quant_int8(A)
A, A_scale, _ = ops.scaled_int8_quant(A, A_scale)
A, A_scale = per_token_quant_int8(A)
else:
assert not per_act_token
assert len(block_shape) == 2
_, block_k = block_shape[0], block_shape[1]
A, A_scale = per_token_group_quant_int8(A, block_k)
@@ -73,19 +162,40 @@ def _int8_quantize(
return A, A_scale
def _mxfp4_quantize(
A: torch.Tensor,
A_scale: Optional[torch.Tensor],
per_act_token_quant: bool,
block_shape: Optional[list[int]] = None,
) -> tuple[torch.Tensor, None]:
assert block_shape is None
if not current_platform.supports_mx():
A = quant_dequant_mxfp4(A)
else:
raise NotImplementedError()
return A, None
def moe_kernel_quantize_input(
A: torch.Tensor,
A_scale: Optional[torch.Tensor],
qtype: Optional[torch.dtype],
per_channel_quant: bool,
quant_dtype: Union[None, torch.dtype, str],
per_act_token_quant: bool,
block_shape: Optional[list[int]] = None,
is_fp4_scale_swizzled: bool = True,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
if qtype == torch.float8_e4m3fn:
return _fp8_quantize(A, A_scale, per_channel_quant, block_shape)
elif qtype == torch.int8:
return _int8_quantize(A, A_scale, per_channel_quant, block_shape)
if quant_dtype == torch.float8_e4m3fn:
return _fp8_quantize(A, A_scale, per_act_token_quant, block_shape)
elif quant_dtype == torch.int8:
return _int8_quantize(A, A_scale, per_act_token_quant, block_shape)
elif quant_dtype == torch.uint8: # nvfp4
return _fp4_quantize(A,
A_scale,
is_sf_swizzled_layout=is_fp4_scale_swizzled)
elif quant_dtype == "mxfp4":
return _mxfp4_quantize(A, A_scale, per_act_token_quant, block_shape)
else:
assert A_scale is None
return A, A_scale
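For the new "mxfp4" path, a hedged sketch (not part of the commit; assumes a CUDA GPU where current_platform.supports_mx() is False, so the fallback quant-dequant branch is taken):

import torch

a = torch.randn(16, 2880, device="cuda", dtype=torch.bfloat16)
aq, a_scale = moe_kernel_quantize_input(
    A=a, A_scale=None, quant_dtype="mxfp4", per_act_token_quant=False)
# Without native MX support the activation is fake-quantized instead of
# being cast to a real mxfp4 format, so no scale tensor is produced.
assert a_scale is None

# The fp8 path, by contrast, returns a quantized tensor plus scales.
aq8, s8 = moe_kernel_quantize_input(
    A=a, A_scale=None, quant_dtype=torch.float8_e4m3fn,
    per_act_token_quant=True)
assert aq8.dtype == torch.float8_e4m3fn and s8 is not None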
@@ -97,3 +207,62 @@ def _fp8_perm(m: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
return m.view(dtype=torch.uint8)[idx, ...].view(dtype=m.dtype)
else:
return m[idx, ...]
def normalize_scales_shape(
scales: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
if scales is not None:
if scales.numel() == 1:
scales = scales.view(1, 1)
else:
scales = scales.view(-1, scales.size(-1))
return scales
def normalize_batched_scales_shape(
scales: Optional[torch.Tensor],
num_experts: int,
) -> Optional[torch.Tensor]:
if scales is not None and scales.ndim < 3:
if scales.numel() == 1:
scales = scales.view(1)
scales = torch.repeat_interleave(scales, num_experts,
dim=0).view(num_experts, 1, 1)
else:
scales = scales.view(num_experts, -1, scales.size(-1))
return scales
def _validate_scale_shape(
a: torch.Tensor,
a_scale: Optional[torch.Tensor],
per_act_token_quant: bool,
block_shape: Optional[list[int]],
) -> None:
if a_scale is None:
return
if not per_act_token_quant and block_shape is None:
assert a_scale.numel() == 1, f"{a_scale.shape}"
elif per_act_token_quant:
assert a_scale.shape[0] == a.shape[0] and a_scale.shape[1] == 1, (
f"{a_scale.shape[0]} == {a.shape[0]} and {a_scale.shape[1]} == 1")
else:
assert block_shape is not None
expected = (a.shape[0], cdiv(a.shape[1], block_shape[1]))
assert a_scale.shape == expected, f"{a_scale.shape} == {expected}"
def extract_required_args(
extra_args: Optional[dict[str, Any]],
required_keys: list[str],
) -> tuple[Any, ...]:
if extra_args is None:
raise ValueError("`extra_args` must be provided.")
missing_keys = [k for k in required_keys if k not in extra_args]
if missing_keys:
raise ValueError(f"Missing keys in `extra_args`: {missing_keys}")
return tuple(extra_args[k] for k in required_keys)
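# Hedged usage sketch (illustration only): pulling kernel-specific arguments
# out of an opaque `extra_args` dict; the key names here are made up for the
# example.
def _extract_required_args_example() -> None:
    extra_args = {"gemm1_alpha": 1.702, "tile_tokens_dim": 32}
    gemm1_alpha, tile_tokens_dim = extract_required_args(
        extra_args, ["gemm1_alpha", "tile_tokens_dim"])
    assert gemm1_alpha == 1.702 and tile_tokens_dim == 32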

View File

@@ -36,6 +36,7 @@ QuantizationMethods = Literal[
"moe_wna16",
"torchao",
"auto-round",
"mxfp4",
]
QUANTIZATION_METHODS: list[str] = list(get_args(QuantizationMethods))
@@ -108,6 +109,7 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
from .marlin import MarlinConfig
from .modelopt import ModelOptFp8Config, ModelOptNvFp4Config
from .moe_wna16 import MoeWNA16Config
from .mxfp4 import Mxfp4Config
from .neuron_quant import NeuronQuantConfig
from .ptpc_fp8 import PTPCFp8Config
from .qqq import QQQConfig
@@ -143,6 +145,7 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
"moe_wna16": MoeWNA16Config,
"torchao": TorchAOConfig,
"auto-round": AutoRoundConfig,
"mxfp4": Mxfp4Config,
}
# Update the `method_to_config` with customized quantization methods.
method_to_config.update(_CUSTOMIZED_METHOD_TO_QUANT_CONFIG)

View File

@@ -0,0 +1,581 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Callable, Optional
import torch
from torch.nn.parameter import Parameter
from vllm import envs
from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEConfig,
FusedMoEMethodBase, fused_experts)
from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import (
triton_kernel_moe_forward)
from vllm.model_executor.layers.linear import (LinearBase,
UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
prepare_moe_fp4_layer_for_marlin)
from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
_can_support_mxfp4, _swizzle_mxfp4)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
is_layer_skipped)
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types
from vllm.utils import (has_triton_kernels, is_torch_equal_or_newer,
next_power_of_2, round_up)
if (envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8
or envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16):
# from flashinfer.fused_moe import cutlass_fused_moe
from flashinfer import (mxfp8_quantize, shuffle_matrix_a,
shuffle_matrix_sf_a, trtllm_fp4_block_scale_moe)
class Mxfp4Config(QuantizationConfig):
def __init__(self, ignored_layers: Optional[list[str]] = None):
super().__init__()
self.ignored_layers = ignored_layers
@classmethod
def from_config(cls, config):
return cls()
@classmethod
def get_min_capability(cls) -> int:
return 80
@classmethod
def get_name(cls) -> QuantizationMethods:
return "mxfp4"
@classmethod
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
return [torch.bfloat16]
@classmethod
def get_config_filenames(cls) -> list[str]:
return []
def get_quant_method(self, layer: torch.nn.Module,
prefix: str) -> Optional["QuantizeMethodBase"]:
from vllm.attention.layer import Attention # Avoid circular import
if isinstance(layer, LinearBase):
if self.ignored_layers and is_layer_skipped(
prefix=prefix,
ignored_layers=self.ignored_layers,
fused_mapping=self.packed_modules_mapping):
return UnquantizedLinearMethod()
raise NotImplementedError("Mxfp4 linear layer is not implemented")
elif isinstance(layer, FusedMoE):
return Mxfp4MoEMethod(layer.moe_config)
elif isinstance(layer, Attention):
raise NotImplementedError(
"Mxfp4 attention layer is not implemented")
return None
class Mxfp4MoEMethod(FusedMoEMethodBase):
def __init__(self, moe: FusedMoEConfig):
super().__init__()
self.topk_indices_dtype = None
self.moe = moe
self.use_marlin = self._should_use_marlin()
def _should_use_marlin(self):
if envs.VLLM_MXFP4_USE_MARLIN is not None:
return envs.VLLM_MXFP4_USE_MARLIN
# if current_platform.is_cuda() and \
# not current_platform.has_device_capability(100):
# if not current_platform.is_device_capability(90):
# # marlin kernel has better performance on ampere
# return True
# if not has_triton_kernels():
# return True
# if not is_torch_equal_or_newer("2.8.0"):
# return True
return False
def create_weights(self, layer: torch.nn.Module, num_experts: int,
hidden_size: int, intermediate_size_per_partition: int,
params_dtype: torch.dtype, **extra_weight_attrs):
self.num_experts = num_experts
weight_dtype = torch.uint8
scale_dtype = torch.uint8
# FIXME (zyongye): ship after torch and safetensors support mxfp4
# is_torch_mxfp4_available = (
# hasattr(torch, "float4_e2m1fn_x2") and
# hasattr(torch, "float8_e8m0fnu"))
# if is_torch_mxfp4_available:
# weight_dtype = torch.float4_e2m1fn_x2
# scale_dtype = torch.float8_e8m0fnu
mxfp4_block = 32
intermediate_size_per_partition_after_pad = \
intermediate_size_per_partition
if self.use_marlin:
# The moe marlin kernel requires that for each linear
# n % 256 == 0 and k % 128 == 0.
# In gate_up_proj:
# n = 2 * intermediate_size_per_partition_after_pad
# k = hidden_size
# In down_proj
# n = hidden_size
# k = intermediate_size_per_partition_after_pad
intermediate_size_per_partition_after_pad = round_up(
intermediate_size_per_partition, 128)
hidden_size = round_up(hidden_size, 256)
layer.params_dtype = params_dtype
layer.num_experts = num_experts
layer.hidden_size = hidden_size
layer.intermediate_size_per_partition = \
intermediate_size_per_partition_after_pad
elif (envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8
or envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16):
# Pad the intermediate size to be a multiple of 2 * mxfp4_block so it
# can hold a non-uniformly sharded tensor as well as the swizzled
# layout; the extra padding also improves performance.
intermediate_size_per_partition_after_pad = round_up(
intermediate_size_per_partition, 256)
hidden_size = round_up(hidden_size, 256)
elif current_platform.is_rocm():
intermediate_size_per_partition_after_pad = round_up(
intermediate_size_per_partition, 128)
else:
intermediate_size_per_partition_after_pad = round_up(
intermediate_size_per_partition, 64)
self.intermediate_size = intermediate_size_per_partition_after_pad
self.hidden_size = hidden_size
# Fused gate_up_proj (column parallel)
w13_weight = torch.nn.Parameter(
torch.zeros(
num_experts,
2 * intermediate_size_per_partition_after_pad,
hidden_size // 2,
dtype=weight_dtype,
),
requires_grad=False,
)
layer.register_parameter("w13_weight", w13_weight)
set_weight_attrs(w13_weight, extra_weight_attrs)
w13_weight_scale = torch.nn.Parameter(
torch.zeros(
num_experts,
2 * intermediate_size_per_partition_after_pad,
hidden_size // mxfp4_block,
dtype=scale_dtype,
),
requires_grad=False,
)
layer.register_parameter("w13_weight_scale", w13_weight_scale)
set_weight_attrs(w13_weight_scale, extra_weight_attrs)
w13_bias = torch.nn.Parameter(
torch.zeros(
num_experts,
2 * intermediate_size_per_partition_after_pad,
dtype=torch.bfloat16,
),
requires_grad=False,
)
layer.register_parameter("w13_bias", w13_bias)
set_weight_attrs(w13_bias, extra_weight_attrs)
# down_proj (row parallel)
w2_weight = torch.nn.Parameter(
torch.zeros(
num_experts,
hidden_size,
intermediate_size_per_partition_after_pad // 2,
dtype=weight_dtype,
),
requires_grad=False,
)
layer.register_parameter("w2_weight", w2_weight)
set_weight_attrs(w2_weight, extra_weight_attrs)
w2_weight_scale = torch.nn.Parameter(
torch.zeros(
num_experts,
hidden_size,
intermediate_size_per_partition_after_pad // mxfp4_block,
dtype=scale_dtype,
),
requires_grad=False,
)
layer.register_parameter("w2_weight_scale", w2_weight_scale)
set_weight_attrs(w2_weight_scale, extra_weight_attrs)
w2_bias = torch.nn.Parameter(
torch.zeros(
num_experts,
hidden_size,
dtype=torch.bfloat16,
),
requires_grad=False,
)
layer.register_parameter("w2_bias", w2_bias)
set_weight_attrs(w2_bias, extra_weight_attrs)
def process_weights_after_loading(self, layer):
if self.use_marlin:
prepare_moe_fp4_layer_for_marlin(layer)
elif (envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8
or envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16):
layer.gemm1_alpha = Parameter(torch.tensor(
[1.702] * self.num_experts, dtype=torch.float32).cuda(),
requires_grad=False)
layer.gemm1_beta = Parameter(torch.tensor(
[1.0] * self.num_experts, dtype=torch.float32).cuda(),
requires_grad=False)
layer.gemm1_clamp_limit = Parameter(torch.tensor(
[7.0] * self.num_experts, dtype=torch.float32).cuda(),
requires_grad=False)
sf_block_size = 32 # mxfp4 block size
assert (layer.w13_weight.dim() == 3
and layer.w13_weight.shape[0] == self.num_experts
and layer.w13_weight.shape[1] == self.intermediate_size * 2
and layer.w13_weight.shape[2] == self.hidden_size // 2)
assert (layer.w13_weight_scale.dim() == 3
and layer.w13_weight_scale.shape[0] == self.num_experts
and layer.w13_weight_scale.shape[1]
== self.intermediate_size * 2
and layer.w13_weight_scale.shape[2]
== self.hidden_size // sf_block_size)
assert (layer.w2_weight.dim() == 3
and layer.w2_weight.shape[0] == self.num_experts
and layer.w2_weight.shape[1] == self.hidden_size and
layer.w2_weight.shape[2] == self.intermediate_size // 2)
assert (layer.w2_weight_scale.dim() == 3
and layer.w2_weight_scale.shape[1] == self.hidden_size
and layer.w2_weight_scale.shape[2]
== self.intermediate_size // sf_block_size)
assert (layer.w13_bias.dim() == 2
and layer.w13_bias.shape[0] == self.num_experts
and layer.w13_bias.shape[1] == self.intermediate_size * 2)
assert (layer.w2_bias.dim() == 2
and layer.w2_bias.shape[0] == self.num_experts
and layer.w2_bias.shape[1] == self.hidden_size)
w13_weight_scale = layer.w13_weight_scale.data
w2_weight_scale = layer.w2_weight_scale.data
w13_weight = layer.w13_weight.data
w2_weight = layer.w2_weight.data
w13_bias = layer.w13_bias.data.to(torch.float32)
w2_bias = layer.w2_bias.data.to(torch.float32)
# Swap w1 and w3 because the definition of
# swiglu differs in trtllm-gen
def swap_every_two_rows(x, axis=-1):
shape = x.shape
if axis < 0:
axis = len(shape) + axis
# Create a new shape with pairs swapped along specified axis
new_shape = list(shape)
new_shape[axis] = shape[axis] // 2
new_shape.insert(axis + 1, 2)
# Reshape to expose pairs, swap them, and reshape back
x = x.reshape(*new_shape)
x = x.flip(axis + 1)
new_shape = list(shape)
return x.reshape(*new_shape)
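# Illustration (comment only): for rows [r0, r1, r2, r3] along the swapped
# axis the result is [r1, r0, r3, r2] -- adjacent pairs are exchanged so the
# interleaved gate/up ordering matches what trtllm-gen expects.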
w13_weight_scale = swap_every_two_rows(w13_weight_scale, -2)
w13_weight = swap_every_two_rows(w13_weight, -2)
w13_bias = swap_every_two_rows(w13_bias, -1)
# Do not interleave as the checkpoint is already interleaved
# Shuffle weights and scaling factors for transposed mma output
gemm1_weights_mxfp4_shuffled = []
gemm1_scales_mxfp4_shuffled = []
gemm2_weights_mxfp4_shuffled = []
gemm2_scales_mxfp4_shuffled = []
gemm1_bias_shuffled = []
gemm2_bias_shuffled = []
epilogue_tile_m = 128 # FIXME: this depends on the kernel internals
for i in range(self.num_experts):
gemm1_weights_mxfp4_shuffled.append(
shuffle_matrix_a(w13_weight[i].view(torch.uint8),
epilogue_tile_m))
gemm1_scales_mxfp4_shuffled.append(
shuffle_matrix_sf_a(w13_weight_scale[i].view(torch.uint8),
epilogue_tile_m))
gemm1_bias_shuffled.append(
shuffle_matrix_a(w13_bias[i].clone().reshape(-1, 1),
epilogue_tile_m))
gemm2_weights_mxfp4_shuffled.append(
shuffle_matrix_a(w2_weight[i].view(torch.uint8),
epilogue_tile_m))
gemm2_scales_mxfp4_shuffled.append(
shuffle_matrix_sf_a(w2_weight_scale[i].view(torch.uint8),
epilogue_tile_m))
gemm2_bias_shuffled.append(
shuffle_matrix_a(w2_bias[i].clone().reshape(-1, 1),
epilogue_tile_m))
w13_weight = torch.stack(gemm1_weights_mxfp4_shuffled)
w13_weight_scale = torch.stack(
gemm1_scales_mxfp4_shuffled).reshape(
self.num_experts, 2 * self.intermediate_size,
self.hidden_size // sf_block_size).view(
torch.float8_e4m3fn)
w2_weight = torch.stack(gemm2_weights_mxfp4_shuffled)
w2_weight_scale = torch.stack(gemm2_scales_mxfp4_shuffled).reshape(
self.num_experts, self.hidden_size, self.intermediate_size //
sf_block_size).view(torch.float8_e4m3fn)
layer.w13_weight = Parameter(w13_weight, requires_grad=False)
layer.w13_weight_scale = Parameter(w13_weight_scale,
requires_grad=False)
layer.w2_weight = Parameter(w2_weight, requires_grad=False)
layer.w2_weight_scale = Parameter(w2_weight_scale,
requires_grad=False)
layer.w13_bias = Parameter(
torch.stack(gemm1_bias_shuffled).reshape(self.num_experts, -1),
requires_grad=False)
layer.w2_bias = Parameter(torch.stack(gemm2_bias_shuffled).reshape(
self.num_experts, -1),
requires_grad=False)
elif has_triton_kernels():
from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig
w13_bias = layer.w13_bias.to(torch.float32)
w2_bias = layer.w2_bias.to(torch.float32)
layer.w13_bias = Parameter(w13_bias, requires_grad=False)
layer.w2_bias = Parameter(w2_bias, requires_grad=False)
# FIXME: the number of warps needs to be adjusted based on batch size;
# this only applies to batched mode
if self.moe.use_ep:
num_warps = 4 if envs.VLLM_MOE_DP_CHUNK_SIZE <= 512 else 8
else:
num_warps = 8
w13_weight, w13_flex, w13_scale = _swizzle_mxfp4(
layer.w13_weight, layer.w13_weight_scale, num_warps)
w2_weight, w2_flex, w2_scale = _swizzle_mxfp4(
layer.w2_weight, layer.w2_weight_scale, num_warps)
self.w13_precision_config = PrecisionConfig(
weight_scale=w13_scale, flex_ctx=FlexCtx(rhs_data=w13_flex))
self.w2_precision_config = PrecisionConfig(
weight_scale=w2_scale, flex_ctx=FlexCtx(rhs_data=w2_flex))
self.w13_weight_triton_tensor = w13_weight
self.w2_weight_triton_tensor = w2_weight
# Delete the original weights to save memory on a single GPU
del layer.w13_weight
del layer.w2_weight
layer.w13_weight = None
layer.w2_weight = None
torch.cuda.empty_cache()
else:
# Fallback: upcast the mxfp4 weights to bf16 and use the regular
# Triton fused_experts path
from .triton_kernels_numerics_details.mxfp import upcast_from_mxfp
w13_weight = upcast_from_mxfp(
layer.w13_weight, layer.w13_weight_scale, dtype=torch.bfloat16, axis=-1
)
w2_weight = upcast_from_mxfp(
layer.w2_weight, layer.w2_weight_scale, dtype=torch.bfloat16, axis=-1
)
del layer.w13_weight
del layer.w2_weight
del layer.w13_weight_scale
del layer.w2_weight_scale
layer.w13_weight = Parameter(w13_weight.data, requires_grad=False)
layer.w2_weight = Parameter(w2_weight.data, requires_grad=False)
torch.cuda.empty_cache()
def _get_tile_tokens_dim(self, x: torch.Tensor, top_k: int):
# Number of tokens in the input tensor.
num_tokens = x.shape[0]
# Factor to account for the imbalance of the experts.
# The factor equals
# max_real_num_tokens_per_expert / perfect_num_tokens_per_expert:
# - 1.0 means a perfect expert distribution.
# - > 1.0 means some experts receive more
#   tokens than the perfect distribution.
# - < 1.0 does not make sense.
imbalance_factor = 1.3
# Calculate the number of tokens per expert
# assuming perfect distribution.
num_tokens_per_expert = (num_tokens * top_k) // self.num_experts
# Apply the imbalance factor.
num_tokens_per_expert = int(num_tokens_per_expert * imbalance_factor)
# And pad the number to the next power of 2.
tile_tokens_dim = next_power_of_2(num_tokens_per_expert)
# Cap to 8-64 tokens per CTA tile
# as it's the range supported by the kernel.
tile_tokens_dim = min(max(tile_tokens_dim, 8), 64)
return tile_tokens_dim
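# Worked example (comment only): with num_tokens=4096, top_k=4 and
# num_experts=128, the perfect share is 4096 * 4 // 128 = 128 tokens per
# expert; applying the 1.3 imbalance factor gives 166, the next power of two
# is 256, and clamping to the kernel's supported 8..64 range yields
# tile_tokens_dim = 64.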
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
expert_load_view: Optional[torch.Tensor] = None,
logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if enable_eplb:
raise NotImplementedError("EPLB is not supported for mxfp4")
if self.use_marlin:
topk_weights, topk_ids = FusedMoE.select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias)
return torch.ops.vllm.fused_marlin_moe(
x,
layer.w13_weight,
layer.w2_weight,
layer.w13_bias,
layer.w2_bias,
layer.w13_weight_scale,
layer.w2_weight_scale,
router_logits,
topk_weights,
topk_ids,
global_scale1=None,
global_scale2=None,
quant_type_id=scalar_types.float4_e2m1f.id,
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts,
activation=activation,
expert_map=expert_map)
assert _can_support_mxfp4(
use_grouped_topk, topk_group, num_expert_group, expert_map,
custom_routing_function, e_score_correction_bias,
apply_router_weight_on_input, scoring_func, activation,
expert_load_view, logical_to_physical_map,
logical_replica_count), (
"MXFP4 are not supported with this configuration.")
if (envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8
or envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16):
assert not self.moe.use_ep, (
"EP is not supported for flashinfer mxfp4 moe backend yet.")
if envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16:
assert x.dtype == torch.bfloat16
x_quant = x
x_scale = None
else:
x_quant, x_scale = mxfp8_quantize(x, False) # to mxfp8
x_scale = x_scale.view(torch.float8_e4m3fn).reshape(-1)
trtllm_gen_output = trtllm_fp4_block_scale_moe(
router_logits.to(torch.bfloat16),
None, # routing_bias
x_quant,
x_scale,
layer.w13_weight, # uint8 (e2m1 x 2)
layer.w13_weight_scale, # uint8 (e4m3 x 2)
layer.w13_bias, # fp32 per expert per channel
layer.gemm1_alpha, # fp32 per expert
layer.gemm1_beta, # fp32 per expert
layer.gemm1_clamp_limit, # fp32 per expert
layer.w2_weight, # uint8 (e2m1 x 2)
layer.w2_weight_scale, # ue8m0
layer.w2_bias, # fp32 per expert per channel
None, # output1_scale_scalar
None, # output1_scale_gate_scalar
None, # output2_scale_scalar
self.num_experts,
top_k,
None, # n_group
None, # topk_group
self.intermediate_size, # padded to multiple of 256
0, # local_expert_offset
self.num_experts, # local num experts
None,
self._get_tile_tokens_dim(x, top_k),
1 if renormalize else 0, # routing_method_type, renormalize
True, # do finalize
)[0]
return trtllm_gen_output
elif has_triton_kernels():
return triton_kernel_moe_forward(
hidden_states=x,
w1=self.w13_weight_triton_tensor,
w2=self.w2_weight_triton_tensor,
gating_output=router_logits,
topk=top_k,
renormalize=renormalize,
global_num_experts=global_num_experts,
expert_map=expert_map,
w1_bias=layer.w13_bias,
w2_bias=layer.w2_bias,
w1_precision=self.w13_precision_config,
w2_precision=self.w2_precision_config,
apply_router_weight_on_input=apply_router_weight_on_input,
)
else:
topk_weights, topk_ids = FusedMoE.select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias)
return fused_experts(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts,
expert_map=expert_map,
w1_bias=layer.w13_bias,
w2_bias=layer.w2_bias,
)
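# Hedged usage sketch (illustration only): once registered, the config can be
# selected explicitly; the model name below is a placeholder.
# from vllm import LLM
# llm = LLM(model="openai/gpt-oss-20b", quantization="mxfp4")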

View File

@@ -6,14 +6,16 @@ from typing import Any, Callable, Optional
import torch
import torch.nn.functional as F
import vllm.envs as envs
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.quark.schemes import QuarkScheme
from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
OCP_MX_BLOCK_SIZE, per_token_group_quant_mxfp4)
OCP_MX_BLOCK_SIZE, dequant_mxfp4, quant_dequant_mxfp4)
from vllm.model_executor.parameter import (GroupQuantScaleParameter,
PackedvLLMParameter)
from vllm.platforms import current_platform
logger = init_logger(__name__)
__all__ = ["QuarkW4A4MXFP4"]
@@ -25,7 +27,29 @@ class QuarkW4A4MXFP4(QuarkScheme):
self.qscheme = "per_group"
self.weight_quant_spec = weight_quant_spec
self.input_quant_spec = input_quant_spec
self.emulate = not current_platform.supports_mx()
self.static_input_scales = not input_quant_spec.get("is_dynamic")
if self.static_input_scales:
raise NotImplementedError(
"QuarkW4A4MXFP4 with static input scales is currently not "
"implemented. Please open an issue.")
if not current_platform.supports_mx():
self.emulate = True
logger.warning_once(
"The current platform does not support native MXFP4 "
"computation. Simulated weight dequantization and activation "
"QDQ (quantize and dequantize) will be used, with the linear "
"layers computed in high precision.")
else:
self.emulate = True
logger.warning_once(
"The current platform supports native MXFP4 "
"computation, but kernels are not yet integrated in vLLM. "
"Simulated weight dequantization and activation "
"QDQ (quantize and dequantize) will be used, with the linear "
"layers computed in high precision.")
@classmethod
def get_min_capability(cls) -> int:
@@ -37,43 +61,6 @@ class QuarkW4A4MXFP4(QuarkScheme):
layer.weight_scale = torch.nn.Parameter(layer.weight_scale.data,
requires_grad=False)
if self.emulate:
try:
from quark.torch.export.nn.modules import realquantizer
from quark.torch.quantization.config.config import (
QuantizationSpec)
except ImportError as err:
raise ImportError(
"The package `amd-quark` is required to use AMD Quark "
"MX-FP4 models. Please install it with `pip install "
"amd-quark`.") from err
weight_quant_spec = QuantizationSpec.from_dict(
self.weight_quant_spec)
weight_quantizer = realquantizer.get_real_quantizer(
qspec=weight_quant_spec,
quantizer=None,
real_quantized=True,
reorder=False,
float_dtype=self.out_dtype,
scale_shape=layer.weight_scale.shape,
zero_point_shape=None,
)
weight_quantizer.scale.data = layer.weight_scale.data
if not envs.VLLM_QUARK_EMU_MEM_OPT:
layer.weight = torch.nn.Parameter(
weight_quantizer(layer.weight.data).to(self.out_dtype),
requires_grad=False,
)
else:
self.weight_quantizer = weight_quantizer
layer.weight_scale = None
# This call is necessary to release the scales memory.
torch.cuda.empty_cache()
def create_weights(self, layer: torch.nn.Module,
output_partition_sizes: list[int],
input_size_per_partition: int,
@@ -116,11 +103,10 @@ class QuarkW4A4MXFP4(QuarkScheme):
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
if self.emulate:
if envs.VLLM_QUARK_EMU_MEM_OPT:
dq_w = self.weight_quantizer(layer.weight).to(self.out_dtype)
else:
dq_w = layer.weight
qdq_x, _ = per_token_group_quant_mxfp4(x, OCP_MX_BLOCK_SIZE)
return F.linear(qdq_x, dq_w, bias)
dq_w = dequant_mxfp4(layer.weight, layer.weight_scale, x.dtype)
x = quant_dequant_mxfp4(x)
return F.linear(x, dq_w, bias)
else:
raise NotImplementedError()

View File

@@ -0,0 +1,158 @@
import triton
import triton.language as tl
# fmt: off
MXFP_BLOCK_SIZE = tl.constexpr(32)
@triton.jit
def _get_max_quant_val(dtype: tl.constexpr):
if dtype == tl.uint8:
return 6.0
elif dtype == tl.float8e5:
return 57344.0
elif dtype == tl.float8e4nv:
return 448.0
else:
tl.static_assert(False, f"Invalid {dtype=}")
@triton.jit
def _compute_quant_and_scale(src_tensor, valid_src_mask, mx_tensor_dtype: tl.constexpr,
DEQUANT_SCALE_ROUNDING_MODE: tl.constexpr = 0):
is_fp8: tl.constexpr = mx_tensor_dtype == tl.float8e4nv or mx_tensor_dtype == tl.float8e5
BLOCK_SIZE_OUT_DIM: tl.constexpr = src_tensor.shape[0]
BLOCK_SIZE_QUANT_DIM: tl.constexpr = src_tensor.shape[1]
BLOCK_SIZE_QUANT_MX_SCALE: tl.constexpr = src_tensor.shape[1] // MXFP_BLOCK_SIZE
# Explicit cast to fp32 since most ops are not supported on bfloat16. We avoid needless conversions to and from bf16
f32_tensor = src_tensor.to(tl.float32)
abs_tensor = tl.abs(f32_tensor)
abs_tensor = tl.where(valid_src_mask, abs_tensor, -1.0) # Don't consider padding tensors in scale computation
abs_tensor = tl.reshape(abs_tensor, [BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_MX_SCALE, MXFP_BLOCK_SIZE])
max_val = tl.max(abs_tensor, axis=2, keep_dims=True)
dequant_scale = max_val / _get_max_quant_val(mx_tensor_dtype)
if DEQUANT_SCALE_ROUNDING_MODE == 0:
# DequantScaleRoundingMode.ROUND_UP
# compute 2 ** ceil(log2(dequant_scale))
# Adding 0x007FFFFF bumps the exponent by 1 unless the mantissa is all zeros.
# Corner case: an exponent of 0xFF would overflow, but that value is already
# NaN, so assume we don't care.
dequant_scale_exponent = (dequant_scale.to(tl.uint32, bitcast=True) + 0x007FFFFF) & 0x7F800000
else:
# DequantScaleRoundingMode.ROUND_DOWN
# compute 2 ** floor(log2(dequant_scale))
assert DEQUANT_SCALE_ROUNDING_MODE == 1
dequant_scale_exponent = dequant_scale.to(tl.uint32, bitcast=True) & 0x7F800000
dequant_scale_rounded = dequant_scale_exponent.to(tl.float32, bitcast=True)
quant_scale = tl.where(dequant_scale_rounded == 0, 0, 1.0 / dequant_scale_rounded)
f32_tensor = tl.reshape(f32_tensor, [BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_MX_SCALE, MXFP_BLOCK_SIZE])
quant_tensor = f32_tensor * quant_scale
# Reshape the tensors after scaling
quant_tensor = quant_tensor.reshape([BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_DIM])
# Set the invalid portions of the tensor to 0. This will ensure that any padding tensors are 0 in the mx format.
quant_tensor = tl.where(valid_src_mask, quant_tensor, 0)
dequant_scale_exponent = dequant_scale_exponent.reshape([BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_MX_SCALE])
# First, we simply extract the exponent part of the scales and store the result
dequant_scale_exponent = (dequant_scale_exponent >> 23).to(tl.uint8)
# Now we must convert the tensors to the mx format.
if is_fp8:
out_tensor = quant_tensor.to(mx_tensor_dtype)
else:
quant_tensor = quant_tensor.to(tl.uint32, bitcast=True)
signs = quant_tensor & 0x80000000
exponents = (quant_tensor >> 23) & 0xFF
mantissas = (quant_tensor & 0x7FFFFF)
# 0.25 <= x < 0.75 maps to 0.5, a denormal number
E8_BIAS = 127
E2_BIAS = 1
# Move the implicit leading 1 bit into the mantissa for denormals
adjusted_exponents = tl.core.sub(E8_BIAS, exponents + 1, sanitize_overflow=False)
mantissas = tl.where(exponents < E8_BIAS, (0x400000 | (mantissas >> 1)) >> adjusted_exponents, mantissas)
# For normal numbers, we change the bias from 127 to 1, and for subnormals, we keep exponent as 0.
exponents = tl.maximum(exponents, E8_BIAS - E2_BIAS) - (E8_BIAS - E2_BIAS)
# Combine sign, exponent, and mantissa, while saturating.
# Round to nearest with ties broken upward by adding 1 at the bit to the right of the LSB, then shift right
e2m1_tmp = tl.minimum((((exponents << 2) | (mantissas >> 21)) + 1) >> 1, 0x7)
e2m1_value = ((signs >> 28) | e2m1_tmp).to(tl.uint8)
e2m1_value = tl.reshape(e2m1_value, [BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_DIM // 2, 2])
evens, odds = tl.split(e2m1_value)
out_tensor = evens | (odds << 4)
return out_tensor, dequant_scale_exponent
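# Example (comment only): for a (128, 64) source block targeting fp4, the
# per-group maxima yield a (128, 2) uint8 e8m0 scale block (64 / 32 groups per
# row) and the packed e2m1 output is a (128, 32) uint8 block.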
@triton.jit
def _downcast_to_mxfp(mx_tensor_ptr, stride_mxt_outer, stride_mxt_quant: tl.constexpr,
mx_scale_ptr, stride_mx_scale_outer, stride_mx_scale_quant,
src_ptr, stride_src_outer, stride_src_quant,
outer_dim, quant_dim,
BLOCK_SIZE_OUT_DIM: tl.constexpr, BLOCK_SIZE_QUANT_DIM: tl.constexpr,
DEQUANT_SCALE_ROUNDING_MODE: tl.constexpr):
tl.static_assert(stride_mxt_quant == 1, f"Output stride, {stride_mxt_quant=} must be 1.")
tl.static_assert(BLOCK_SIZE_QUANT_DIM % MXFP_BLOCK_SIZE == 0, f"{BLOCK_SIZE_QUANT_DIM=} must be a multiple of 32")
# uint8 signifies two fp4 e2m1 values packed into a single byte
mx_tensor_dtype: tl.constexpr = mx_tensor_ptr.dtype.element_ty
tl.static_assert(mx_tensor_dtype == tl.uint8 or (mx_tensor_dtype == tl.float8e4nv or mx_tensor_dtype == tl.float8e5),
f"Invalid {mx_tensor_dtype=}. Must be uint8 or float8.")
src_dtype: tl.constexpr = src_ptr.dtype.element_ty
tl.static_assert(mx_scale_ptr.dtype.element_ty == tl.uint8, f"{mx_scale_ptr.dtype.element_ty=} must be uint8")
tl.static_assert((src_dtype == tl.bfloat16) or (src_dtype == tl.float16) or (src_dtype == tl.float32), f"{src_dtype=} must be bfloat16 or float16 or float32")
is_fp4: tl.constexpr = mx_tensor_dtype == tl.uint8
outer_block = tl.program_id(0).to(tl.int64)
quant_block = tl.program_id(1).to(tl.int64)
K_DIVISOR: tl.constexpr = 2 if is_fp4 else 1
BLOCK_SIZE_QUANT_MX_SCALE: tl.constexpr = BLOCK_SIZE_QUANT_DIM // MXFP_BLOCK_SIZE
BLOCK_SIZE_QUANT_MX_TENSOR: tl.constexpr = BLOCK_SIZE_QUANT_DIM // K_DIVISOR
start_src_quant = quant_block * BLOCK_SIZE_QUANT_DIM
start_mx_scale_quant = quant_block * BLOCK_SIZE_QUANT_MX_SCALE
start_mx_quant = quant_block * BLOCK_SIZE_QUANT_MX_TENSOR
start_out = outer_block * BLOCK_SIZE_OUT_DIM
src_ptr += start_src_quant * stride_src_quant + start_out * stride_src_outer
mx_scale_ptr += start_mx_scale_quant * stride_mx_scale_quant + start_out * stride_mx_scale_outer
mx_tensor_ptr += start_mx_quant * stride_mxt_quant + start_out * stride_mxt_outer
offs_src_quant = tl.arange(0, BLOCK_SIZE_QUANT_DIM)[None, :].to(tl.int64)
offs_mxt_quant = tl.arange(0, BLOCK_SIZE_QUANT_MX_TENSOR)[None, :].to(tl.int64)
offs_scale_quant = tl.arange(0, BLOCK_SIZE_QUANT_MX_SCALE)[None, :].to(tl.int64)
offs_outer = tl.arange(0, BLOCK_SIZE_OUT_DIM)[:, None].to(tl.int64)
mask_src_quant = start_src_quant + offs_src_quant < quant_dim
mask_n = start_out + offs_outer < outer_dim
full_mask_src = mask_src_quant & mask_n
mask_mxt_quant = start_mx_quant + offs_mxt_quant < tl.cdiv(quant_dim, K_DIVISOR)
full_mask_mxt = mask_mxt_quant & mask_n
scale_mask_k = start_mx_scale_quant + offs_scale_quant < tl.cdiv(quant_dim, MXFP_BLOCK_SIZE)
full_scale_mask = scale_mask_k & mask_n
src_tensor_offsets = offs_src_quant * stride_src_quant + offs_outer * stride_src_outer
mx_scale_offsets = offs_scale_quant * stride_mx_scale_quant + offs_outer * stride_mx_scale_outer
mx_tensor_offsets = offs_mxt_quant * stride_mxt_quant + offs_outer * stride_mxt_outer
src_tensor = tl.load(src_ptr + src_tensor_offsets, mask=full_mask_src)
out_tensor, scale_tensor = _compute_quant_and_scale(src_tensor, full_mask_src, mx_tensor_dtype,
DEQUANT_SCALE_ROUNDING_MODE)
tl.store(mx_scale_ptr + mx_scale_offsets, scale_tensor, mask=full_scale_mask)
tl.store(mx_tensor_ptr + mx_tensor_offsets, out_tensor, mask=full_mask_mxt)
@triton.jit(repr=lambda _: "_dequantize_mxfp8")
def _dequantize_mxfp8_fn(input, mask, pid=None):
return _compute_quant_and_scale(input, mask, tl.float8e4nv)

View File

@@ -0,0 +1,136 @@
import triton
import triton.language as tl
from ._downcast_to_mxfp import MXFP_BLOCK_SIZE
# fmt: off
@triton.jit
def _upcast_from_mxfp(out_ptr, stride_o_outer, stride_o_quant: tl.constexpr, mx_scale_ptr, stride_scale_outer,
stride_scale_quant, mx_tensor_ptr, stride_tensor_outer, stride_tensor_quant: tl.constexpr,
outer_dim, quant_dim, BLOCK_SIZE_OUT_DIM: tl.constexpr, BLOCK_SIZE_QUANT_DIM: tl.constexpr):
tl.static_assert(stride_o_quant == 1, "the weight must be contiguous in the k dimension for mx")
tl.static_assert(BLOCK_SIZE_QUANT_DIM % MXFP_BLOCK_SIZE == 0, "BLOCK_SIZE_K must be a multiple of 32")
# uint8 signifies two fp4 e2m1 values packed into a single byte
mx_tensor_dtype: tl.constexpr = mx_tensor_ptr.dtype.element_ty
dst_dtype: tl.constexpr = out_ptr.dtype.element_ty
tl.static_assert(dst_dtype == tl.float16 or (dst_dtype == tl.bfloat16 or dst_dtype == tl.float32))
tl.static_assert(
mx_tensor_dtype == tl.uint8
or ((mx_tensor_dtype == tl.float8e4nv or mx_tensor_dtype == tl.float8e5) or mx_tensor_dtype == dst_dtype),
"mx_tensor_ptr must be uint8 or float8 or dst_dtype")
tl.static_assert(mx_scale_ptr.dtype.element_ty == tl.uint8, "mx_scale_ptr must be uint8")
# Determine if we are dealing with fp8 types.
is_fp4: tl.constexpr = mx_tensor_dtype == tl.uint8
is_fp8: tl.constexpr = mx_tensor_dtype == tl.float8e4nv or mx_tensor_dtype == tl.float8e5
K_DIVISOR: tl.constexpr = 2 if is_fp4 else 1
BLOCK_SIZE_QUANT_MX_SCALE: tl.constexpr = BLOCK_SIZE_QUANT_DIM // MXFP_BLOCK_SIZE
BLOCK_SIZE_QUANT_MX_TENSOR: tl.constexpr = BLOCK_SIZE_QUANT_DIM // K_DIVISOR
# Compute starting indices for the quantized (packed) dimension and the outer dimension.
outer_block = tl.program_id(0).to(tl.int64)
quant_block = tl.program_id(1).to(tl.int64)
start_mxt_quant = quant_block * BLOCK_SIZE_QUANT_MX_TENSOR
start_out_quant = quant_block * BLOCK_SIZE_QUANT_DIM
start_mx_scale_quant = quant_block * BLOCK_SIZE_QUANT_MX_SCALE
start_out = outer_block * BLOCK_SIZE_OUT_DIM
mx_tensor_ptr += start_mxt_quant * stride_tensor_quant + start_out * stride_tensor_outer
mx_scale_ptr += start_mx_scale_quant * stride_scale_quant + start_out * stride_scale_outer
out_ptr += start_out * stride_o_outer + start_out_quant * stride_o_quant
# Compute offsets and masks.
offs_src_quant = tl.arange(0, BLOCK_SIZE_QUANT_MX_TENSOR)[None, :].to(tl.int64)
offs_out_quant = tl.arange(0, BLOCK_SIZE_QUANT_DIM)[None, :].to(tl.int64)
offs_outer = tl.arange(0, BLOCK_SIZE_OUT_DIM)[:, None].to(tl.int64)
offs_scale = tl.arange(0, BLOCK_SIZE_QUANT_MX_SCALE)[None, :].to(tl.int64)
mask_outer = start_out + offs_outer < outer_dim
mask_out_quant = start_out_quant + offs_out_quant < quant_dim
full_mask_out = mask_out_quant & mask_outer
mask_src_quant = start_mxt_quant + offs_src_quant < tl.cdiv(quant_dim, K_DIVISOR)
full_mask_src = mask_src_quant & mask_outer
mask_scale = start_mx_scale_quant + offs_scale < tl.cdiv(quant_dim, MXFP_BLOCK_SIZE)
full_scale_mask = mask_scale & mask_outer
tensor_offsets = offs_src_quant * stride_tensor_quant + offs_outer * stride_tensor_outer
scale_offsets = offs_scale * stride_scale_quant + offs_outer * stride_scale_outer
out_offsets = offs_out_quant * stride_o_quant + offs_outer * stride_o_outer
# Load the packed tensor and scale.
tensor = tl.load(mx_tensor_ptr + tensor_offsets, mask=full_mask_src)
scale = tl.load(mx_scale_ptr + scale_offsets, mask=full_scale_mask)
# Upcast the scale to the destination type.
if dst_dtype == tl.bfloat16:
# dst_scale = (scale.to(tl.uint16) << 7).to(dst_dtype, bitcast=True)
dst_scale = (scale.to(tl.uint16) << 7).to(tl.uint16).to(tl.bfloat16, bitcast=True)
else:
dst_scale = (scale.to(tl.uint32) << 23).to(tl.float32, bitcast=True)
if dst_dtype == tl.float16:
dst_scale = dst_scale.to(tl.float16)
# Now upcast the tensor.
intermediate_dtype: tl.constexpr = tl.bfloat16 if dst_dtype == tl.float32 else dst_dtype
if is_fp8:
dst_tensor = tensor.to(intermediate_dtype)
if tensor.dtype == tl.float8e5:
from_e_bits: tl.constexpr = 5
from_m_bits: tl.constexpr = 2
to_e_bits: tl.constexpr = 8 if intermediate_dtype == tl.bfloat16 else 5
to_m_bits: tl.constexpr = 7 if intermediate_dtype == tl.bfloat16 else 10
# Preserve infs and nans. FIXME Fp8E5M2_to_Bf16 doesn't preserve them!
non_finite_mask_src: tl.constexpr = ((1 << from_e_bits) - 1) << from_m_bits
non_finite_mask_dst: tl.constexpr = ((1 << to_e_bits) - 1) << to_m_bits
dst_tensor = tl.where(
(tensor.to(tl.uint8, bitcast=True) & non_finite_mask_src) == non_finite_mask_src,
(dst_tensor.to(tl.uint16, bitcast=True) | non_finite_mask_dst).to(intermediate_dtype, bitcast=True),
dst_tensor,
)
else:
assert is_fp4
dst_bias: tl.constexpr = 127 if intermediate_dtype == tl.bfloat16 else 15
dst_0p5: tl.constexpr = 16128 if intermediate_dtype == tl.bfloat16 else 0x3800
dst_m_bits: tl.constexpr = 7 if intermediate_dtype == tl.bfloat16 else 10
# e2m1
em0 = tensor & 0x07
em1 = tensor & 0x70
x0 = (em0.to(tl.uint16) << (dst_m_bits - 1)) | ((tensor & 0x08).to(tl.uint16) << 12)
x1 = (em1.to(tl.uint16) << (dst_m_bits - 5)) | ((tensor & 0x80).to(tl.uint16) << 8)
# Three cases:
# 1) x is normal and non-zero: Correct bias
x0 = tl.where((em0 & 0x06) != 0, x0 + ((dst_bias - 1) << dst_m_bits), x0)
x1 = tl.where((em1 & 0x60) != 0, x1 + ((dst_bias - 1) << dst_m_bits), x1)
# 2) x is subnormal (x == 0bs001 where s is the sign): Map to +-0.5 in the dst type
x0 = tl.where(em0 == 0x01, dst_0p5 | (x0 & 0x8000), x0)
x1 = tl.where(em1 == 0x10, dst_0p5 | (x1 & 0x8000), x1)
# 3) x is zero, do nothing
if intermediate_dtype == tl.bfloat16:
dst_tensor = tl.interleave(x0, x1).to(tl.uint16).to(tl.bfloat16, bitcast=True)
else:
dst_tensor = tl.interleave(x0, x1).to(tl.float16, bitcast=True)
# dst_tensor = dst_tensor.to(dst_dtype)
if dst_dtype == tl.bfloat16:
dst_tensor = dst_tensor.to(tl.bfloat16)
elif dst_dtype == tl.float16:
dst_tensor = dst_tensor.to(tl.float16)
else:
dst_tensor = dst_tensor.to(tl.float32)
# Reshape for proper broadcasting: the scale was stored with a 32-element "inner" grouping.
dst_tensor = dst_tensor.reshape([BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_MX_SCALE, MXFP_BLOCK_SIZE])
dst_scale = dst_scale.reshape([BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_MX_SCALE, 1])
scale = scale.reshape(dst_scale.shape)
out_tensor = dst_tensor * dst_scale
# Correct any NaNs encoded via the scale.
out_tensor = tl.where(scale == 0xFF, float("nan"), out_tensor)
out_tensor = out_tensor.reshape([BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_DIM])
tl.store(out_ptr + out_offsets, out_tensor, mask=full_mask_out)
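# Example (comment only): an e8m0 scale byte stores a biased exponent E, so
# (E << 7) bitcast to bfloat16 (or E << 23 bitcast to float32) reconstructs
# 2 ** (E - 127); e.g. E = 0x7F gives 1.0 and E = 0x80 gives 2.0. The special
# value 0xFF encodes NaN and is handled explicitly above.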

View File

@@ -0,0 +1,303 @@
# isort: off
# fmt: off
from enum import Enum
import triton
import torch
import torch.nn.functional as F
from ._upcast_from_mxfp import _upcast_from_mxfp
from ._downcast_to_mxfp import _downcast_to_mxfp, _dequantize_mxfp8_fn, MXFP_BLOCK_SIZE
# -----------------------------------------------------------------------------
# Dequantization / Quantization Utilities
# -----------------------------------------------------------------------------
class DequantScaleRoundingMode(Enum):
ROUND_UP = 0
ROUND_DOWN = 1
def downcast_to_mxfp(src_tensor: torch.Tensor, out_quant_type: torch.dtype, axis: int,
DEQUANT_SCALE_ROUNDING_MODE: DequantScaleRoundingMode = DequantScaleRoundingMode.ROUND_UP):
"""
Convert the src weights to mx format. The src weight is quantized along the axis dimension.
If out_quant_type is torch.uint8, we output mxfp4 where two e2m1 values are packed into a single byte.
Note that this means the k_dim of the tensor will be half of the logical k_dim.
If out_quant_type is torch.float8_e4m3fn or torch.float8_e5m2, we output mxfp8, with the float8 values stored
in their respective formats.
"""
ndim = src_tensor.ndim
assert -ndim <= axis < ndim, f"Invalid axis {axis=}"
axis = axis if axis >= 0 else axis + ndim
# downcast
src_tensor = src_tensor.transpose(axis, src_tensor.ndim - 1)
is_fp4 = out_quant_type == torch.uint8
is_fp8 = out_quant_type in (torch.float8_e4m3fn, torch.float8_e5m2)
assert is_fp4 or is_fp8
divisor = 2 if is_fp4 else 1
L = src_tensor.shape[-1]
if is_fp4:
assert L % 2 == 0, f"axis dim must be divisible by 2 for e2m1. Got {L}"
out_shape = src_tensor.shape[:-1] + (L // divisor, )
out_scale_shape = src_tensor.shape[:-1] + (triton.cdiv(L, MXFP_BLOCK_SIZE), )
out_quant_tensor = src_tensor.new_empty(out_shape, dtype=out_quant_type)
out_scale = src_tensor.new_empty(out_scale_shape, dtype=torch.uint8)
if src_tensor.numel() > 0:
kernel_src_tensor = src_tensor.reshape(-1, src_tensor.shape[-1])
kernel_quant_tensor = out_quant_tensor.view(-1, out_quant_tensor.shape[-1])
kernel_scale = out_scale.view(-1, out_scale.shape[-1])
BLOCK_OUT_DIM = 128
BLOCK_QUANT_DIM = MXFP_BLOCK_SIZE.value
grid_out = triton.cdiv(kernel_src_tensor.shape[0], BLOCK_OUT_DIM)
grid_quant = triton.cdiv(kernel_src_tensor.shape[1], BLOCK_QUANT_DIM)
_downcast_to_mxfp[(grid_out, grid_quant)](kernel_quant_tensor, *kernel_quant_tensor.stride(), kernel_scale,
*kernel_scale.stride(), kernel_src_tensor, *kernel_src_tensor.stride(),
*kernel_src_tensor.shape, BLOCK_OUT_DIM, BLOCK_QUANT_DIM,
DEQUANT_SCALE_ROUNDING_MODE.value, num_warps=8)
out_quant_tensor = out_quant_tensor.transpose(axis, src_tensor.ndim - 1)
out_scale = out_scale.transpose(axis, src_tensor.ndim - 1)
return out_quant_tensor, out_scale
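# Example (comment only): a (E, N, K) bf16 weight quantized along axis=-1 with
# out_quant_type=torch.uint8 yields a packed (E, N, K // 2) uint8 tensor plus
# a (E, N, K // 32) uint8 e8m0 scale tensor.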
def upcast_from_mxfp(tensor: torch.Tensor, scale: torch.Tensor, dtype: torch.dtype, axis: int):
"""
Upcasts an mxfp (packed) weight tensor back to float16, bfloat16, or float32.
The function assumes that the tensors were quantized along the given axis.
It permutes the tensor so that the quantized axis is last, reshapes to 2D,
launches the Triton upcast kernel, and then unpermutes back to the original order.
"""
ndim = tensor.ndim
assert -ndim <= axis < ndim, f"Invalid axis {axis=}"
axis = axis if axis >= 0 else axis + ndim
assert tensor.ndim == scale.ndim, (f"Weight and scale must have the same number of dimensions. "
f"Got {tensor.ndim=} and {scale.ndim=}")
# dtype checks
assert tensor.dtype in {torch.uint8, torch.float8_e5m2, torch.float8_e4m3fn}, \
f"Invalid tensor dtype {tensor.dtype=}"
assert scale.dtype == torch.uint8, f"Invalid scale dtype {scale.dtype=}"
assert dtype in (torch.float16, torch.bfloat16, torch.float32), f"Invalid output dtype {dtype=}"
# upcast
logical_quant_dim = tensor.shape[axis] * (2 if tensor.dtype == torch.uint8 else 1)
tensor = tensor.transpose(axis, tensor.ndim - 1).contiguous()
scale = scale.transpose(axis, scale.ndim - 1).contiguous()
out = torch.empty((*tensor.shape[:-1], logical_quant_dim), dtype=dtype, device=tensor.device)
reshaped_out = out.view(-1, out.shape[-1])
reshaped_tensor = tensor.view(-1, tensor.shape[-1])
reshaped_scale = scale.view(-1, scale.shape[-1])
BLOCK_OUT_DIM = 128
BLOCK_QUANT_DIM = MXFP_BLOCK_SIZE.value
blocks_out_dim = triton.cdiv(reshaped_out.shape[0], BLOCK_OUT_DIM)
blocks_quant_dim = triton.cdiv(reshaped_out.shape[1], BLOCK_QUANT_DIM)
_upcast_from_mxfp[(blocks_out_dim, blocks_quant_dim)](reshaped_out, *reshaped_out.stride(), reshaped_scale,
*reshaped_scale.stride(), reshaped_tensor,
*reshaped_tensor.stride(), *reshaped_out.shape, BLOCK_OUT_DIM,
BLOCK_QUANT_DIM, num_warps=8)
out = out.transpose(axis, scale.ndim - 1).contiguous()
return out
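# Hedged round-trip sketch (illustration only): quantize a bf16 matrix to
# mxfp4 and dequantize it back. The shapes and the CUDA device are
# assumptions; mxfp4 is lossy, so only a coarse match with the input is
# expected.
def _mxfp4_roundtrip_example() -> None:
    w = torch.randn(128, 256, dtype=torch.bfloat16, device="cuda")
    w_q, w_scale = downcast_to_mxfp(w, torch.uint8, axis=-1)
    assert w_q.shape == (128, 128) and w_scale.shape == (128, 8)
    w_dq = upcast_from_mxfp(w_q, w_scale, torch.bfloat16, axis=-1)
    assert w_dq.shape == w.shape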
# ------------
def right_shift_unsigned(x, shift):
# CUDA torch does not support bit ops on uint32, so we need to mask to get unsigned right shift
return (x >> shift) & ((1 << (32 - shift)) - 1)
def get_max_quant_val(dtype: torch.dtype):
d = {torch.uint8: 6.0, torch.float8_e5m2: 57344.0, torch.float8_e4m3fn: 448.0}
assert dtype in d
return d[dtype]
def downcast_to_mxfp_torch(src_tensor: torch.Tensor, out_quant_type: torch.dtype, axis: int,
DEQUANT_SCALE_ROUNDING_MODE: DequantScaleRoundingMode = DequantScaleRoundingMode.ROUND_UP):
"""
Converts the src tensor to the output format specified by out_quant_type.
axis: The axis along which the tensors are contiguous and quantization is applied.
DEQUANT_SCALE_ROUNDING_MODE: 0 for ROUND_UP, 1 for ROUND_DOWN.
Returns:
out_quant_tensor: Quantized tensor in mx format.
• For mxfp8, the output has the same shape as src_tensor.
• For mxfp4, the size along the axis is halved, and the tensor is returned as a torch.uint8.
scale: Scale tensor (stored as uint8) computed per group of 32 elements along the axis.
Its shape is the same as src_tensor except that the axis is replaced by ceil(L/32),
where L is the original length along that axis.
"""
# This should probably be packed into its own tiny class
ndim = src_tensor.ndim
assert -ndim <= axis < ndim, f"Invalid axis {axis=}"
assert src_tensor.dtype in {torch.float32, torch.bfloat16,
torch.float16}, f"Invalid input tensor dtype {src_tensor.dtype}"
axis = axis if axis >= 0 else axis + ndim
is_fp4 = out_quant_type == torch.uint8
is_fp8 = "float8" in str(out_quant_type)
assert is_fp4 or is_fp8, f"Invalid input tensor dtype {out_quant_type}"
device = src_tensor.device
# For mxfp4 conversion, we assume the contiguous axis length is even.
if is_fp4:
axis_shape = src_tensor.size(axis)
assert axis_shape % 2 == 0, "For mxfp4 conversion the contiguous axis length must be even."
# Permute the tensor so that the contiguous axis becomes the last dimension.
src = src_tensor.transpose(axis, src_tensor.ndim - 1).to(torch.float32)
axis_shape = src.shape[-1]
# Pad the axis to be divisible by 32, in case it is not.
next_multiple = triton.cdiv(axis_shape, MXFP_BLOCK_SIZE) * MXFP_BLOCK_SIZE
pad_amount = next_multiple - axis_shape
padded_src = F.pad(src, (0, pad_amount))
valid_mask = F.pad(torch.ones_like(src, dtype=torch.bool), (0, pad_amount))
padded_axis_shape = padded_src.size(-1) # now divisible by 32
# --- Compute per-group maximums for scale ---
# Set padded entries to -1 so they don't affect the max.
abs_f = torch.abs(padded_src)
abs_f = torch.where(valid_mask, abs_f, torch.tensor(-1.0, device=device, dtype=padded_src.dtype))
# Reshape the last dimension into groups of 32.
new_shape = padded_src.shape[:-1] + (padded_axis_shape // MXFP_BLOCK_SIZE, MXFP_BLOCK_SIZE)
abs_groups = abs_f.view(*new_shape)
# Compute maximum along the group dimension (of size 32).
max_val, _ = abs_groups.max(dim=-1, keepdim=True)
# Choose a max quantization value depending on type.
max_quant_val = get_max_quant_val(out_quant_type)
dequant_scale = max_val / max_quant_val # shape: (..., padded_axis_shape//32, 1)
# Convert to int to round the FP32 scale, prior to quantization!
ds_int = dequant_scale.view(torch.int32)
if DEQUANT_SCALE_ROUNDING_MODE == DequantScaleRoundingMode.ROUND_UP:
ds_int_rounded = (ds_int + 0x007FFFFF) & 0x7F800000
else:
ds_int_rounded = ds_int & 0x7F800000
# Reinterpret back as float32.
dequant_scale_rounded = ds_int_rounded.view(torch.float32)
# Compute the quantization scale.
quant_scale = torch.where(dequant_scale_rounded == 0, torch.tensor(0.0, device=device), 1.0 / dequant_scale_rounded)
# Quantize the tensor
orig_padded_shape = padded_src.shape
padded_src_groups = padded_src.view(*new_shape)
quant_tensor = padded_src_groups * quant_scale
# Reshape back to the original shape and trim padding
quant_tensor = quant_tensor.view(orig_padded_shape)
quant_tensor = quant_tensor[..., :axis_shape]
# Finally, convert the quantized tensor to the target format
if is_fp8:
# Conversion must use satfinite PTX, so clamp before the conversion in torch to emulate this behavior
quant_tensor = torch.clamp(quant_tensor, -max_quant_val, max_quant_val)
out_weight = quant_tensor.to(out_quant_type)
else:
assert is_fp4, f"Invalid output quantization type {out_quant_type}"
# For mxfp4, perform bit-level manipulation and pack two 4-bit values per uint8.
# First, reinterpret the quantized tensor bits.
q_int = quant_tensor.contiguous().view(torch.int32)
# Extract sign, exponent, and mantissa.
signs = q_int & 0x80000000
exponents = right_shift_unsigned(q_int, 23) & 0xFF
mantissas = q_int & 0x7FFFFF
E8_BIAS = 127
E2_BIAS = 1
# Adjust mantissas for subnormals.
mantissas = torch.where(exponents < E8_BIAS, (0x400000 | right_shift_unsigned(mantissas, 1)) >>
(E8_BIAS - exponents - 1), mantissas)
exponents = torch.maximum(exponents, torch.tensor(E8_BIAS - E2_BIAS, device=device)) - (E8_BIAS - E2_BIAS)
e2m1_tmp = right_shift_unsigned(((exponents << 2) | right_shift_unsigned(mantissas, 21)) + 1, 1)
e2m1_tmp = torch.minimum(e2m1_tmp, torch.tensor(0x7, device=device))
e2m1_value = (right_shift_unsigned(signs, 28) | e2m1_tmp).to(torch.uint8) # shape: (..., even_axis_shape)
# Pack pairs of 4-bit values along the last dimension.
e2m1_value = e2m1_value.view(*e2m1_value.shape[:-1], axis_shape // 2, 2)
evens = e2m1_value[..., 0]
odds = e2m1_value[..., 1]
out_weight = evens | (odds << 4) # shape: (..., axis_shape//2)
# --- Process and output the scale ---
dq_scale = (ds_int_rounded.view(*dequant_scale.shape) >> 23).to(torch.uint8) # shape: (..., axis_shape//32, 1)
dq_scale = dq_scale.squeeze(-1)
out_weight = out_weight.transpose(axis, src_tensor.ndim - 1)
dq_scale = dq_scale.transpose(axis, src_tensor.ndim - 1)
return out_weight, dq_scale
def cvt_e2m1_to_fp32(input_tensor):
assert input_tensor.dtype == torch.uint8
input_tensor = input_tensor.to(torch.int32)
evens = input_tensor & 0xF
odds = (input_tensor >> 4) & 0xF
vals = [0.0, 0.5, 1, 1.5, 2, 3, 4, 6]
outputs = torch.tensor(vals, dtype=torch.float32, device=input_tensor.device)
outputs = torch.cat([outputs, -outputs])
even_floats = outputs[evens]
odd_floats = outputs[odds]
output_tensor = torch.stack([even_floats, odd_floats], dim=-1)
output_tensor = output_tensor.view(*input_tensor.shape[:-1], -1)
return output_tensor
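# Example (comment only): for the packed byte 0x93, the low nibble 0x3 decodes
# to 1.5 and the high nibble 0x9 decodes to -0.5, so every uint8 element
# expands into two float32 values along the output's last dimension.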
def upcast_from_mxfp_torch(tensor: torch.Tensor, scale: torch.Tensor, target_dtype: torch.dtype, axis: int):
"""
Converts the mxfp4/mxfp8 tensor to the target format specified by target_dtype.
axis: The axis along which dequantization is applied.
Returns:
out_weight: Tensor in the target format.
"""
ndim = tensor.ndim
assert -ndim <= axis < ndim, f"Invalid axis {axis=}"
is_fp8 = tensor.dtype == torch.float8_e4m3fn or tensor.dtype == torch.float8_e5m2
assert is_fp8 or tensor.dtype == torch.uint8, f"Invalid input quantization type {tensor.dtype}"
# Permute the tensor and scale so that the quantization axis becomes the last dimension
axis = axis if axis >= 0 else axis + ndim
scale = scale.transpose(axis, scale.ndim - 1)
tensor = tensor.transpose(axis, tensor.ndim - 1)
dq_scale = (scale.to(torch.int32) << 23).view(torch.float32) # Shift to the exponent and bitcast to fp32
if tensor.dtype == torch.uint8:
fp32_tensor = cvt_e2m1_to_fp32(tensor)
else:
fp32_tensor = tensor.to(torch.float32)
logical_quant_dim = tensor.shape[-1] * (2 if tensor.dtype == torch.uint8 else 1)
axis_shape = fp32_tensor.size(-1)
padded_axis_shape = triton.cdiv(logical_quant_dim, MXFP_BLOCK_SIZE) * MXFP_BLOCK_SIZE
pad_size = padded_axis_shape - axis_shape
padded_tensor = F.pad(fp32_tensor, (0, pad_size))
new_axis_shape = padded_tensor.shape[-1]
new_shape = padded_tensor.shape[:-1] + (new_axis_shape // MXFP_BLOCK_SIZE, MXFP_BLOCK_SIZE)
padded_tensor = padded_tensor.view(*new_shape)
dq_scale_padded = dq_scale.unsqueeze(-1) # shape: [..., ceil(axis_shape/32), 1]
out_padded = padded_tensor * dq_scale_padded
# Flatten back and remove the padded tail
out_padded = out_padded.view(*fp32_tensor.shape[:-1], new_axis_shape)
out_tensor = out_padded[..., :axis_shape]
out_tensor = out_tensor.to(target_dtype).contiguous()
out_tensor = out_tensor.transpose(axis, tensor.ndim - 1)
return out_tensor
dequantize_mxfp8_fn = _dequantize_mxfp8_fn
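# Hedged consistency sketch (illustration only): the pure-torch reference path
# is intended to agree with the Triton kernels above; the shapes and the CUDA
# device are assumptions, and a loose tolerance is used for safety.
def _torch_vs_triton_mxfp4_example() -> None:
    x = torch.randn(64, 128, dtype=torch.bfloat16, device="cuda")
    q_ref, s_ref = downcast_to_mxfp_torch(x, torch.uint8, axis=-1)
    ref = upcast_from_mxfp_torch(q_ref, s_ref, torch.bfloat16, axis=-1)
    q_tri, s_tri = downcast_to_mxfp(x, torch.uint8, axis=-1)
    tri = upcast_from_mxfp(q_tri, s_tri, torch.bfloat16, axis=-1)
    assert torch.allclose(ref.float(), tri.float(), atol=1e-2, rtol=1e-2)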

View File

@@ -261,6 +261,13 @@ def marlin_permute_scales(s: torch.Tensor, size_k: int, size_n: int,
return s
def marlin_permute_bias(s: torch.Tensor) -> torch.Tensor:
origin_shape = s.shape
_, scale_perm_single = get_scale_perms()
s = s.reshape((-1, len(scale_perm_single)))[:, scale_perm_single]
return s.reshape(*origin_shape).contiguous()
def marlin_moe_permute_scales(
s: torch.Tensor,
size_k: int,
@@ -410,6 +417,7 @@ def apply_gptq_marlin_linear(
output = ops.gptq_marlin_gemm(reshaped_x,
None,
weight,
bias,
weight_scale,
None,
weight_zp,
@@ -425,9 +433,6 @@ def apply_gptq_marlin_linear(
use_fp32_reduce=use_fp32_reduce,
is_zp_float=False)
if bias is not None:
output.add_(bias) # In-place add
return output.reshape(out_shape)
@@ -456,6 +461,7 @@ def apply_awq_marlin_linear(
output = ops.gptq_marlin_gemm(reshaped_x,
None,
weight,
bias,
weight_scale,
None,
weight_zp,
@@ -470,7 +476,4 @@ def apply_awq_marlin_linear(
use_fp32_reduce=use_fp32_reduce,
is_zp_float=False)
if bias is not None:
output.add_(bias) # In-place add
return output.reshape(out_shape)

View File

@@ -8,8 +8,8 @@ import torch
import vllm._custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
USE_FP32_REDUCE_DEFAULT, marlin_make_workspace_new, marlin_permute_scales,
should_use_atomic_add_reduce)
USE_FP32_REDUCE_DEFAULT, marlin_make_workspace_new, marlin_permute_bias,
marlin_permute_scales, should_use_atomic_add_reduce)
from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types
@@ -22,7 +22,7 @@ def is_fp4_marlin_supported():
return current_platform.has_device_capability(80)
def fp4_marlin_process_scales(marlin_scales):
def nvfp4_marlin_process_scales(marlin_scales):
if not (marlin_scales >= 0).all():
logger.warning_once(
"NVFP4 Marlin assumes the scales to be >=0, but has encountered "
@@ -56,7 +56,20 @@ def fp4_marlin_process_scales(marlin_scales):
return marlin_scales
def fp4_marlin_process_global_scale(global_scale):
def mxfp4_marlin_process_scales(marlin_scales):
# 8 is the number of scales used by one thread
marlin_scales = marlin_scales.view(marlin_scales.size(0) // 2, 2, -1, 8)
marlin_scales = marlin_scales.permute(0, 2, 1, 3).reshape(
marlin_scales.size(0) * 2, -1)
# fit the layout of fp8 dequantization
marlin_scales = marlin_scales.view(-1, 4)[:, [0, 2, 1, 3]].view(
marlin_scales.size(0), -1)
marlin_scales = marlin_scales.to(torch.float8_e8m0fnu)
return marlin_scales
def nvfp4_marlin_process_global_scale(global_scale):
assert global_scale.dtype in [torch.half, torch.bfloat16]
fp4_exponent = 2
if global_scale.dtype == torch.half:
@@ -73,7 +86,7 @@ def apply_fp4_marlin_linear(
input: torch.Tensor,
weight: torch.Tensor,
weight_scale: torch.Tensor,
weight_scale_2: torch.Tensor,
weight_scale_2: Optional[torch.Tensor],
workspace: torch.Tensor,
size_n: int,
size_k: int,
@@ -94,6 +107,7 @@ def apply_fp4_marlin_linear(
output = ops.gptq_marlin_gemm(a=reshaped_x,
c=None,
b_q_weight=weight,
b_bias=bias,
b_scales=weight_scale,
global_scale=weight_scale_2,
b_zeros=None,
@@ -107,9 +121,6 @@ def apply_fp4_marlin_linear(
use_atomic_add=use_atomic_add,
use_fp32_reduce=use_fp32_reduce)
if bias is not None:
output.add_(bias) # In-place add
return output.reshape(out_shape)
@@ -120,6 +131,9 @@ def prepare_fp4_layer_for_marlin(layer: torch.nn.Module) -> None:
"be used leveraging the Marlin kernel. This may degrade "
"performance for compute-heavy workloads.")
is_nvfp4 = hasattr(layer, "weight_scale_2")
group_size = 16 if is_nvfp4 else 32
part_size_n = layer.output_size_per_partition
part_size_k = layer.input_size_per_partition
param_dtype = layer.params_dtype
@@ -145,18 +159,35 @@ def prepare_fp4_layer_for_marlin(layer: torch.nn.Module) -> None:
# WEIGHT SCALES
# Permute scales
weight_scale = layer.weight_scale.T.to(param_dtype)
weight_scale = layer.weight_scale.T.contiguous()
if not is_nvfp4:
weight_scale = weight_scale.view(torch.float8_e8m0fnu)
weight_scale = weight_scale.to(param_dtype)
weight_scale = marlin_permute_scales(s=weight_scale,
size_k=part_size_k,
size_n=part_size_n,
group_size=16)
weight_scale = fp4_marlin_process_scales(weight_scale)
layer.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False)
group_size=group_size)
weight_scale_2 = layer.weight_scale_2.to(param_dtype)
weight_scale_2 = fp4_marlin_process_global_scale(weight_scale_2)
layer.weight_scale_2 = torch.nn.Parameter(weight_scale_2,
requires_grad=False)
if is_nvfp4:
weight_scale = nvfp4_marlin_process_scales(weight_scale)
layer.weight_scale = torch.nn.Parameter(weight_scale,
requires_grad=False)
weight_scale_2 = layer.weight_scale_2.to(param_dtype)
weight_scale_2 = nvfp4_marlin_process_global_scale(weight_scale_2)
layer.weight_scale_2 = torch.nn.Parameter(weight_scale_2,
requires_grad=False)
else:
weight_scale = mxfp4_marlin_process_scales(weight_scale)
layer.weight_scale = torch.nn.Parameter(weight_scale,
requires_grad=False)
if hasattr(layer, "bias") and layer.bias is not None:
assert layer.bias.shape == (part_size_n, )
bias = marlin_permute_bias(layer.bias)
layer.bias = torch.nn.Parameter(bias, requires_grad=False)
return
@@ -168,6 +199,9 @@ def prepare_moe_fp4_layer_for_marlin(layer: torch.nn.Module) -> None:
"be used leveraging the Marlin kernel. This may degrade "
"performance for compute-heavy workloads.")
is_nvfp4 = hasattr(layer, "w13_weight_scale_2")
group_size = 16 if is_nvfp4 else 32
e = layer.num_experts
k = layer.hidden_size
n = layer.intermediate_size_per_partition
@@ -208,8 +242,13 @@ def prepare_moe_fp4_layer_for_marlin(layer: torch.nn.Module) -> None:
# WEIGHT SCALES
# Permute scales
for name in ["w13", "w2"]:
scales = getattr(layer, name + "_weight_scale").to(param_dtype)
global_scale = getattr(layer, name + "_weight_scale_2").to(param_dtype)
scales = getattr(layer, name + "_weight_scale")
if not is_nvfp4:
scales = scales.view(torch.float8_e8m0fnu)
scales = scales.to(param_dtype)
if is_nvfp4:
global_scale = getattr(layer,
name + "_weight_scale_2").to(param_dtype)
tensor_list = []
if "w13" in name:
@@ -218,23 +257,47 @@ def prepare_moe_fp4_layer_for_marlin(layer: torch.nn.Module) -> None:
size_n, size_k = k, n
for i in range(e):
marlin_scales = marlin_permute_scales(s=scales[i].T,
scale = scales[i].T
marlin_scales = marlin_permute_scales(s=scale,
size_k=size_k,
size_n=size_n,
group_size=16)
marlin_scales = fp4_marlin_process_scales(marlin_scales)
group_size=group_size)
if is_nvfp4:
marlin_scales = nvfp4_marlin_process_scales(marlin_scales)
else:
marlin_scales = mxfp4_marlin_process_scales(marlin_scales)
tensor_list.append(marlin_scales)
scales = torch.cat([x.unsqueeze(0) for x in tensor_list], 0)
scales = torch.nn.Parameter(scales, requires_grad=False)
setattr(layer, name + "_weight_scale", scales)
global_scale = fp4_marlin_process_global_scale(global_scale)
global_scale = torch.nn.Parameter(global_scale, requires_grad=False)
setattr(layer, name + "_weight_scale_2", global_scale)
if is_nvfp4:
global_scale = nvfp4_marlin_process_global_scale(global_scale)
global_scale = torch.nn.Parameter(global_scale,
requires_grad=False)
setattr(layer, name + "_weight_scale_2", global_scale)
# BIAS
# Permute bias
for name in ["w13_bias", "w2_bias"]:
if not hasattr(layer, name):
continue
bias = getattr(layer, name).to(param_dtype)
tensor_list = []
for i in range(e):
expert_bias = bias[i]
tensor_list.append(marlin_permute_bias(expert_bias))
bias = torch.cat([x.unsqueeze(0) for x in tensor_list], 0)
bias = torch.nn.Parameter(bias, requires_grad=False)
setattr(layer, name, bias)
def rand_marlin_weight_fp4_like(weight, group_size):
def rand_marlin_weight_nvfp4_like(weight, group_size):
assert group_size > 0
size_n, size_k = weight.shape
device = weight.device
@@ -276,8 +339,58 @@ def rand_marlin_weight_fp4_like(weight, group_size):
size_k=size_k,
size_n=size_n,
group_size=group_size)
marlin_scales = fp4_marlin_process_scales(marlin_scales)
marlin_scales = nvfp4_marlin_process_scales(marlin_scales)
global_scale = fp4_marlin_process_global_scale(global_scale)
global_scale = nvfp4_marlin_process_global_scale(global_scale)
return weight_ref.T, marlin_qweight, marlin_scales, global_scale
def rand_marlin_weight_mxfp4_like(weight, group_size):
assert group_size > 0
size_n, size_k = weight.shape
device = weight.device
scales = torch.randint(100,
125, (size_n, size_k // group_size),
dtype=torch.uint8,
device=weight.device)
scales = scales.view(torch.float8_e8m0fnu)
fp4_weight = torch.randint(0,
256, (size_n, size_k // 2),
dtype=torch.uint8,
device=weight.device)
fp4_weight_part_1 = ((fp4_weight & 0b10000000) |
((fp4_weight & 0b01110000) >> 2))
fp4_weight_part_1 = fp4_weight_part_1.view(torch.float8_e4m3fn)
fp4_weight_part_1 = fp4_weight_part_1.to(weight.dtype) * (2**6)
fp4_weight2 = fp4_weight << 4
fp4_weight_part_2 = ((fp4_weight2 & 0b10000000) |
((fp4_weight2 & 0b01110000) >> 2))
fp4_weight_part_2 = fp4_weight_part_2.view(torch.float8_e4m3fn)
fp4_weight_part_2 = fp4_weight_part_2.to(weight.dtype) * (2**6)
weight_ref = torch.cat(
[fp4_weight_part_2.unsqueeze(2),
fp4_weight_part_1.unsqueeze(2)], 2).view(size_n, size_k)
weight_ref = weight_ref * \
scales.repeat_interleave(group_size, 1).to(weight.dtype)
marlin_qweight = ops.gptq_marlin_repack(
b_q_weight=fp4_weight.view(torch.int32).T.contiguous(),
perm=torch.empty(0, dtype=torch.int, device=device),
size_k=size_k,
size_n=size_n,
num_bits=4,
)
marlin_scales = marlin_permute_scales(s=scales.T.to(weight.dtype),
size_k=size_k,
size_n=size_n,
group_size=group_size)
marlin_scales = mxfp4_marlin_process_scales(marlin_scales)
return weight_ref.T, marlin_qweight, marlin_scales.to(torch.float8_e8m0fnu)
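As a hedged, self-contained illustration of the nibble trick above (not part of the diff): re-expressing an e2m1 nibble as an e4m3 byte yields exactly 2**-6 times the intended fp4 value, which is why both halves are scaled by 2**6.
import torch
byte = torch.tensor([0b00110101], dtype=torch.uint8)  # packs fp4 nibbles 0b0011 (hi) and 0b0101 (lo)
hi = ((byte & 0b10000000) | ((byte & 0b01110000) >> 2)).view(torch.float8_e4m3fn)
lo_bits = byte << 4
lo = ((lo_bits & 0b10000000) | ((lo_bits & 0b01110000) >> 2)).view(torch.float8_e4m3fn)
print(hi.to(torch.float32) * 2**6)  # tensor([1.5000])
print(lo.to(torch.float32) * 2**6)  # tensor([3.])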

View File

@@ -1,45 +1,133 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Callable, Optional
import torch
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op, is_torch_equal_or_newer
logger = init_logger(__name__)
OCP_MX_BLOCK_SIZE = 32
def per_token_group_quant_mxfp4(x: torch.Tensor,
block_k: int,
scale_calculation_mode: str = "even"
) -> tuple[torch.Tensor, torch.Tensor]:
def _swizzle_mxfp4(quant_tensor, scale, num_warps):
""" weight swizzle for mxfp4 moe, used for OAI mxfp4 kernel
"""
import triton_kernels.matmul_ogs_details.opt_flags as opt_flags
from triton_kernels.numerics import InFlexData
from triton_kernels.tensor import FP4, convert_layout, wrap_torch_tensor
from triton_kernels.tensor_details import layout
from triton_kernels.tensor_details.layout import StridedLayout
if (current_platform.is_cuda()
and current_platform.is_device_capability(90)
and not is_torch_equal_or_newer("2.8.1")):
logger.warning_once(
"Mxfp4 on hopper is running on torch < 2.8.1, "
"this cause swizling to be disabled, which may "
"cause performance degradation. Please upgrade to torch nightly")
value_layout, value_layout_opts = StridedLayout, dict()
scale_layout, scale_layout_opts = StridedLayout, dict()
else:
value_layout, value_layout_opts = \
layout.make_default_matmul_mxfp4_w_layout(mx_axis=1)
scale_layout, scale_layout_opts = (
layout.make_default_matmul_mxfp4_w_scale_layout(
mx_axis=1, num_warps=num_warps))
if current_platform.is_cuda() and \
current_platform.is_device_capability(100):
constraints = {
"is_persistent": True,
"epilogue_subtile": 1,
}
opt_flags.update_opt_flags_constraints(constraints)
# transpose the tensor so that the quantization axis is on dim1
quant_tensor = quant_tensor.transpose(-2, -1)
scale = scale.transpose(-2, -1)
quant_tensor = convert_layout(wrap_torch_tensor(quant_tensor, dtype=FP4),
value_layout, **value_layout_opts)
scale = convert_layout(wrap_torch_tensor(scale), scale_layout,
**scale_layout_opts)
return quant_tensor, InFlexData(), scale
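# Editorial note: roughly speaking, convert_layout re-tiles the packed fp4
# values and their e8m0 scales into the weight layouts the triton_kernels
# matmul_ogs kernel consumes directly; on Hopper with torch < 2.8.1 the plain
# StridedLayout is kept instead, trading some throughput for compatibility.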
def _can_support_mxfp4(use_grouped_topk: bool = False,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
expert_map: Optional[torch.Tensor] = None,
custom_routing_function: Optional[Callable] = None,
e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
scoring_func: str = "softmax",
activation: str = "swiglu_oai",
expert_load_view: Optional[torch.Tensor] = None,
logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None):
return not (use_grouped_topk or topk_group or num_expert_group
or expert_map or custom_routing_function
or e_score_correction_bias or apply_router_weight_on_input
or scoring_func != "softmax" or activation != "swiglu_oai"
or expert_load_view or logical_to_physical_map
or logical_replica_count)
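A minimal usage sketch of the guard above (illustrative only): the OAI mxfp4 path is taken solely for plain softmax top-k routing with the swiglu_oai activation; anything fancier falls back to another kernel.
assert _can_support_mxfp4()                            # defaults: supported
assert not _can_support_mxfp4(use_grouped_topk=True)   # grouped top-k: unsupported
assert not _can_support_mxfp4(scoring_func="sigmoid")  # non-softmax scoring: unsupported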
def _dequant_mxfp4(x: torch.Tensor, scale: torch.Tensor,
float_dtype: torch.dtype) -> torch.Tensor:
try:
from quark.torch.kernel.hw_emulation.hw_emulation_interface import (
fake_quantize_fp4_fp6_per_group_with_scale)
from quark.torch.quantization.utils import (even_round,
reshape_to_blocks)
from quark.torch.kernel import mx
except ImportError as err:
raise ImportError("The package `amd-quark` is required to use "
"MX-FP4 models. Please install it with `pip install "
"amd-quark`.") from err
axis = -1
block_x = reshape_to_blocks(x, block_k, axis)
amax, _ = torch.max(torch.abs(block_x), dim=-1, keepdim=True)
amax = amax.squeeze(-1)
return mx.dq_mxfp4(x, scale, float_dtype)
# TODO: there are other rounding strategies supported in quark and in the
# config.json that we do not check for here!
if scale_calculation_mode != "even":
raise NotImplementedError(
f"Scale calculation mode {scale_calculation_mode} is not yet "
"supported in MX-FP4 quantization")
scale = even_round(amax, "fp4")
# Apply dequantize(quantize(x)).
x = fake_quantize_fp4_fp6_per_group_with_scale(
x,
scale.to(x.device),
axis=axis,
group_size=block_k,
quant_dtype="fp4",
def _dequant_mxfp4_fake(x: torch.Tensor, scale: torch.Tensor,
float_dtype: torch.dtype) -> torch.Tensor:
return torch.empty((*x.shape[:-1], x.shape[-1] * 2),
dtype=float_dtype,
device=x.device)
def _quant_dequant_mxfp4(x: torch.Tensor,
scale_calculation_mode: str = "even") -> torch.Tensor:
try:
from quark.torch.kernel import mx
except ImportError as err:
raise ImportError("The package `amd-quark` is required to use "
"MX-FP4 models. Please install it with `pip install "
"amd-quark`.") from err
return mx.qdq_mxfp4(x, scale_calculation_mode)
def _quant_dequant_mxfp4_fake(x: torch.Tensor,
scale_calculation_mode: str = "even"
) -> torch.Tensor:
return torch.empty_like(x)
try:
direct_register_custom_op(
op_name="dequant_mxfp4",
op_func=_dequant_mxfp4,
mutates_args=[],
fake_impl=_dequant_mxfp4_fake,
)
dequant_mxfp4 = torch.ops.vllm.dequant_mxfp4
except AttributeError as error:
raise error
return x, scale
try:
direct_register_custom_op(
op_name="quant_dequant_mxfp4",
op_func=_quant_dequant_mxfp4,
mutates_args=[],
fake_impl=_quant_dequant_mxfp4_fake,
)
quant_dequant_mxfp4 = torch.ops.vllm.quant_dequant_mxfp4
except AttributeError as error:
raise error
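A small sketch of the shape contract encoded by the fake impls above (assumption: two fp4 values packed per uint8 byte, as in the checkpoint format), useful when the ops are traced with fake tensors:
x_packed = torch.empty(8, 16, dtype=torch.uint8)             # 32 fp4 values per row
x_scale = torch.empty(8, 1, dtype=torch.uint8)               # one e8m0 scale per 32-value block
out = _dequant_mxfp4_fake(x_packed, x_scale, torch.bfloat16)
assert out.shape == (8, 32)                                  # last dim doubles on dequant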

View File

@@ -3,22 +3,41 @@
"""This file is used for /tests and /benchmarks"""
from collections.abc import Mapping
from types import MappingProxyType
from typing import Optional
from typing import ClassVar, NamedTuple, Optional
import numpy
import torch
from vllm._custom_ops import cutlass_scaled_mm_supports_fp4
from vllm.model_executor.layers.quantization.qqq import (
MARLIN_QQQ_SUPPORTED_NUM_BITS)
from vllm.platforms import current_platform
from vllm.scalar_type import ScalarType, scalar_types
SUPPORTED_GPTQ_QUANT_TYPES = [scalar_types.uint4b8, scalar_types.uint8b128]
SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]
# Use a proxy, since direct NamedTuple subclasses cannot have static members
class _GroupShape(NamedTuple):
row: int
col: int
class GroupShape(_GroupShape):
"""
This class describes the quantization group shape.
It includes static members for common shapes (per-tensor, per-token).
"""
# Aliases for common quantization group shapes
PER_TENSOR: ClassVar['GroupShape']
PER_TOKEN: ClassVar['GroupShape']
GroupShape.PER_TENSOR = GroupShape(-1, -1)
GroupShape.PER_TOKEN = GroupShape(1, -1)
# Normalize the group_shape to the full extent for any dims that are -1
def _normalize_quant_group_shape(x: torch.Tensor, group_shape: tuple[int,
int]):
def _normalize_quant_group_shape(x: torch.Tensor, group_shape: GroupShape):
# -1 means full extent
return (group_shape[0] if group_shape[0] > 0 else x.shape[-2],
group_shape[1] if group_shape[1] > 0 else x.shape[-1])
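A minimal sketch (using the definitions above; torch is already imported in this module) of how the -1 placeholders resolve against a concrete tensor:
x = torch.randn(4, 16, 128)                                            # (..., M=16, K=128)
assert _normalize_quant_group_shape(x, GroupShape.PER_TENSOR) == (16, 128)
assert _normalize_quant_group_shape(x, GroupShape.PER_TOKEN) == (1, 128)
assert _normalize_quant_group_shape(x, GroupShape(1, 32)) == (1, 32)   # per-token, 32-wide groups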
@@ -58,7 +77,7 @@ def group_broadcast(t, shape):
# (i.e. per-token-per-group)
def scaled_quantize(
x: torch.Tensor,
group_shape: tuple[int, int],
group_shape: GroupShape,
quant_dtype: torch.dtype,
) -> tuple[torch.Tensor, torch.Tensor]:
group_shape = _normalize_quant_group_shape(x, group_shape)
@@ -99,7 +118,7 @@ def scaled_quantize(
def scaled_dequantize(
x_q: torch.Tensor,
x_s: torch.Tensor,
group_shape: Optional[tuple[int, int]] = None,
group_shape: Optional[GroupShape] = None,
out_dtype: torch.dtype = torch.float32,
) -> tuple[torch.Tensor, torch.Tensor]:
if group_shape is not None:
@@ -332,6 +351,10 @@ def quantize_weights(w: torch.Tensor,
)
SUPPORTED_GPTQ_QUANT_TYPES = [scalar_types.uint4b8, scalar_types.uint8b128]
SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]
def gptq_quantize_weights(w: torch.Tensor,
quant_type: ScalarType,
group_size: int,
@@ -571,3 +594,56 @@ def awq_pack(
q_w = q_w.reshape((-1, size_n)).contiguous()
return pack_cols(q_w, num_bits, size_k, size_n)
def swizzle_blockscale(scale: torch.Tensor) -> torch.Tensor:
"""
Pad and block-interleave the FP4 block-scales so that they match the data
layout expected by the CUTLASS / FlashInfer kernels.
Parameters
----------
scale: torch.Tensor
FP4 block-scales in ``torch.float8_e4m3fn``, shaped (M, K) or (B, M, K).
Returns
-------
torch.Tensor
The swizzled tensor with the same logical shape as *scale*.
"""
assert scale.dtype == torch.float8_e4m3fn, (
"swizzle_blockscale expects the input tensor to be in "
"torch.float8_e4m3fn format.")
scale_ndim = scale.ndim
if scale_ndim == 2:
scale = scale.unsqueeze(0) # (1, M, K)
assert scale.ndim == 3, "Expected a 2-D or 3-D tensor for block scales."
B, M, K = scale.shape
def _round_up(x: int, m: int) -> int:
return (x + m - 1) // m * m
M_padded = _round_up(M, 128)
K_padded = _round_up(K, 4)
padded = torch.zeros((B, M_padded, K_padded),
dtype=scale.dtype,
device=scale.device)
padded[:B, :M, :K] = scale
# Reshape / permute to the layout required by the kernel.
padded = padded.reshape(B, M_padded // 128, 4, 32, K_padded // 4, 4)
swizzled = padded.permute(0, 1, 4, 3, 2, 5).contiguous().cuda()
if scale_ndim == 2:
return swizzled.reshape(M, K)
return swizzled.reshape(B, M, K)
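A hedged usage sketch (assumes a CUDA device; M and K are chosen as multiples of 128 and 4 so the final reshape round-trips cleanly):
scales = torch.rand(128, 4).to(torch.float8_e4m3fn)   # (M, K) block-scales
swizzled = swizzle_blockscale(scales)
assert swizzled.shape == scales.shape                 # same logical shape, kernel-friendly layout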
def cutlass_fp4_supported() -> bool:
if not current_platform.is_cuda():
return False
capability_tuple = current_platform.get_device_capability()
capability = -1 if capability_tuple is None else capability_tuple.to_int()
return cutlass_scaled_mm_supports_fp4(capability)

View File

@@ -8,7 +8,6 @@ import torch.distributed as dist
from torch import nn
from transformers import GptOssConfig
from vllm import envs
from vllm.attention import Attention, AttentionType
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
@@ -28,7 +27,8 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from vllm.utils import cdiv
from .utils import extract_layer_index, maybe_prefix
from .utils import (AutoWeightsLoader, WeightsMapper, extract_layer_index,
maybe_prefix)
class OAIAttention(nn.Module):
@@ -70,12 +70,9 @@ class OAIAttention(nn.Module):
tp_size = get_tensor_model_parallel_world_size()
# attention_sink_dtype = (torch.float32 if envs.VLLM_USE_TRTLLM_ATTENTION
# else torch.bfloat16)
attention_sink_dtype = torch.bfloat16
self.sinks = torch.nn.Parameter(
torch.empty(config.num_attention_heads // tp_size,
dtype=attention_sink_dtype,
dtype=torch.bfloat16,
requires_grad=False))
self.norm = RMSNorm(config.hidden_size, eps=1e-5)
@@ -207,6 +204,7 @@ class GptOssModel(nn.Module):
super().__init__()
self.config = vllm_config.model_config.hf_config
self.quant_config = vllm_config.quant_config
self.parallel_config = vllm_config.parallel_config
self.config.hidden_size = self.config.hidden_size
self.embedding = VocabParallelEmbedding(
self.config.vocab_size,
@@ -229,8 +227,364 @@ class GptOssModel(nn.Module):
x = self.norm(x)
return x
def _load_weights_mxfp4(
self,
ep_rank_end: int,
ep_rank_start: int,
heads_per_rank: int,
head_start: int,
weights: Iterable[tuple[str, torch.Tensor]],
stacked_params_mapping: list[tuple[str, ...]],
) -> set[str]:
params_dict = dict(self.named_parameters())
loaded_params: set[str] = set()
mxfp4_block = 32
use_ep = self.parallel_config.enable_expert_parallel
num_experts = self.config.num_local_experts
tp_rank = get_tensor_model_parallel_rank()
tp_size = get_tensor_model_parallel_world_size()
intermediate_size = self.config.intermediate_size
intermediate_size_block = intermediate_size // mxfp4_block
per_rank_intermediate_size_block = cdiv(intermediate_size_block,
tp_size)
per_rank_intermediate_size = (per_rank_intermediate_size_block *
mxfp4_block)
# Calculate common slicing bounds for current rank
tp_rank_start = tp_rank * per_rank_intermediate_size
tp_rank_end = min((tp_rank + 1) * per_rank_intermediate_size,
intermediate_size)
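# Worked example (illustrative): with intermediate_size=2880, mxfp4_block=32
# and tp_size=4, there are 90 scale blocks, cdiv(90, 4)=23 blocks per rank,
# i.e. 736 columns per rank; rank 3 then gets the truncated slice
# [2208, 2880) of width 672.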
for name, weight in weights:
# FIXME(woosuk): Remove this after testing.
weight = weight.cuda()
if ".w13_weight_scale" in name:
# Handle MLP gate and up projection weights scale
if use_ep:
narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
else:
narrow_weight = weight[:,
2 * tp_rank_start:2 * tp_rank_end,
...]
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param,
narrow_weight,
weight_name=name,
shard_id=None,
expert_id=None)
loaded_params.add(name)
continue
elif ".w2_weight_scale" in name:
# Handle MLP down projection weights
if use_ep:
narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
else:
narrow_weight = weight[..., tp_rank_start //
mxfp4_block:tp_rank_end //
mxfp4_block]
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param,
narrow_weight,
weight_name=name,
shard_id=None,
expert_id=None)
loaded_params.add(name)
continue
elif ".w13_weight" in name:
# Handle MLP gate and up projection weights
# Flatten weight from (E, 2 * N, block_size, entry_per_block)
# to (E, 2 * N, -1); shouldn't trigger a copy for contiguous tensors
weight = weight.view(num_experts, 2 * intermediate_size,
-1).contiguous()
# Extract gate and up projection parts
# since the weight is shuffled, we can slice directly
if use_ep:
narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
else:
narrow_weight = weight[:,
2 * tp_rank_start:2 * tp_rank_end,
...]
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param,
narrow_weight,
weight_name=name,
shard_id=None,
expert_id=None)
loaded_params.add(name)
continue
elif ".w2_weight" in name:
# Handle MLP down projection weights
# Same flattening here, but since 2 mxfp4 values are packed into 1
# uint8, divide the last dim by 2
weight = weight.view(num_experts, -1,
intermediate_size // 2).contiguous()
if use_ep:
narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
else:
narrow_weight = weight[...,
tp_rank_start // 2:tp_rank_end // 2]
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param,
narrow_weight,
weight_name=name,
shard_id=None,
expert_id=None)
loaded_params.add(name)
continue
elif ".w13_bias" in name:
# Handle MLP gate and up projection biases
# Extract gate and up projection bias parts
if use_ep:
narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
else:
narrow_weight = weight[:,
2 * tp_rank_start:2 * tp_rank_end]
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param,
narrow_weight,
weight_name=name,
shard_id=None,
expert_id=None)
loaded_params.add(name)
continue
elif ".w2_bias" in name:
# Handle MLP down projection bias
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
if use_ep:
weight = weight[ep_rank_start:ep_rank_end, ...]
else:
# (only load on rank 0 to avoid duplication)
if tp_rank != 0:
weight.zero_()
weight_loader(param,
weight,
weight_name=name,
shard_id=None,
expert_id=None)
loaded_params.add(name)
continue
elif "sinks" in name:
# Handle attention sinks (distributed across ranks)
param = params_dict[name]
narrow_weight = weight.narrow(0, head_start, heads_per_rank)
param.data.copy_(narrow_weight)
loaded_params.add(name)
continue
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
if weight_loader == default_weight_loader:
weight_loader(param, weight)
else:
weight_loader(param, weight, shard_id)
break
else:
# Handle all other weights with potential renaming
if name not in params_dict:
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, weight)
loaded_params.add(name)
return loaded_params
def _load_weights_other(
self,
ep_rank_start: int,
ep_rank_end: int,
heads_per_rank: int,
head_start: int,
weights: Iterable[tuple[str, torch.Tensor]],
stacked_params_mapping: list[tuple[str, ...]],
) -> set[str]:
params_dict = dict(self.named_parameters())
loaded_params: set[str] = set()
use_ep = self.parallel_config.enable_expert_parallel
tp_rank = get_tensor_model_parallel_rank()
tp_size = get_tensor_model_parallel_world_size()
intermediate_size = self.config.intermediate_size
per_rank_intermediate_size = cdiv(intermediate_size, tp_size)
# Calculate common slicing bounds for current rank
tp_rank_start = tp_rank * per_rank_intermediate_size
tp_rank_end = min((tp_rank + 1) * per_rank_intermediate_size,
intermediate_size)
for name, weight in weights:
if ".w13_weight" in name:
# Handle MLP gate and up projection weights
# Extract gate and up projection parts
if use_ep:
narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
else:
narrow_weight = weight[:, :,
2 * tp_rank_start:2 * tp_rank_end]
narrow_weight = narrow_weight.permute(0, 2, 1).contiguous()
param = params_dict[name]
param.copy_(narrow_weight)
loaded_params.add(name)
continue
elif ".w2_weight" in name:
# Handle MLP down projection weights
if use_ep:
narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
else:
narrow_weight = weight[:, tp_rank_start:tp_rank_end, :]
narrow_weight = narrow_weight.permute(0, 2, 1).contiguous()
param = params_dict[name]
param.copy_(narrow_weight)
loaded_params.add(name)
continue
elif ".w13_bias" in name:
# Handle MLP gate and up projection biases
# Extract gate and up projection bias parts
if use_ep:
narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
else:
narrow_weight = weight[:,
2 * tp_rank_start:2 * tp_rank_end]
param = params_dict[name]
param.copy_(narrow_weight)
loaded_params.add(name)
continue
elif ".w2_bias" in name:
# Handle MLP down projection bias
if use_ep:
weight = weight[ep_rank_start:ep_rank_end, ...]
else:
# (only load on rank 0 to avoid duplication)
if tp_rank != 0:
weight.zero_()
param = params_dict[name]
param.copy_(weight)
loaded_params.add(name)
continue
elif "sinks" in name:
# Handle attention sinks (distributed across ranks)
param = params_dict[name]
narrow_weight = weight.narrow(0, head_start, heads_per_rank)
param.data.copy_(narrow_weight)
loaded_params.add(name)
continue
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
if weight_loader == default_weight_loader:
weight_loader(param, weight)
else:
weight_loader(param, weight, shard_id)
break
else:
# Handle all other weights with potential renaming
if name not in params_dict:
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, weight)
loaded_params.add(name)
return loaded_params
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
(".qkv", ".q_proj", "q"),
(".qkv", ".k_proj", "k"),
(".qkv", ".v_proj", "v"),
]
tp_rank = get_tensor_model_parallel_rank()
tp_size = get_tensor_model_parallel_world_size()
# Attention heads per rank
heads_per_rank = self.config.num_attention_heads // tp_size
head_start = tp_rank * heads_per_rank
ep_size = get_ep_group().world_size
ep_rank = get_ep_group().rank
num_experts = self.config.num_local_experts
experts_per_rank = num_experts // ep_size
ep_rank_start = ep_rank * experts_per_rank
ep_rank_end = (ep_rank + 1) * experts_per_rank
quant_method = (self.config.quantization_config['quant_method'] if
hasattr(self.config, "quantization_config") else None)
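# The dispatch below keys off the HF config: an mxfp4 gpt-oss checkpoint is
# expected to ship a quantization_config such as {"quant_method": "mxfp4"}
# in config.json, while unquantized checkpoints carry no quantization_config.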
if quant_method == "mxfp4":
return self._load_weights_mxfp4(ep_rank_end, ep_rank_start,
heads_per_rank, head_start,
weights, stacked_params_mapping)
else:
return self._load_weights_other(ep_rank_end, ep_rank_start,
heads_per_rank, head_start,
weights, stacked_params_mapping)
class GptOssForCausalLM(nn.Module):
packed_modules_mapping = {"qkv": ["q_proj", "k_proj", "v_proj"]}
hf_to_vllm_mapper = WeightsMapper(
orig_to_new_substr={
".self_attn.": ".attn.",
".post_attention_layernorm.": ".mlp.norm.",
},
orig_to_new_suffix={
".embed_tokens.weight": ".embedding.weight",
".input_layernorm.weight": ".attn.norm.weight",
".post_attention_layernorm.weight": ".mlp.norm.weight",
# MoE MXFP4 weights
".gate_up_proj_blocks": ".w13_weight",
".down_proj_blocks": ".w2_weight",
".gate_up_proj_scales": ".w13_weight_scale",
".down_proj_scales": ".w2_weight_scale",
# MoE other weights
".gate_up_proj": ".w13_weight",
".down_proj": ".w2_weight",
# MoE Bias
".gate_up_proj_bias": ".w13_bias",
".down_proj_bias": ".w2_bias",
},
)
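# Illustrative mapping (not part of the diff): with this mapper an HF tensor
# name such as "model.layers.3.mlp.experts.gate_up_proj_scales" is looked up
# as "model.layers.3.mlp.experts.w13_weight_scale" in named_parameters().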
def __init__(
self,
@@ -239,16 +593,17 @@ class GptOssForCausalLM(nn.Module):
):
super().__init__()
self.vllm_config = vllm_config
self.model_config = vllm_config.model_config.hf_config
self.config = vllm_config.model_config.hf_config
self.model = GptOssModel(
vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"),
)
self.lm_head = ParallelLMHead(
self.model_config.vocab_size,
self.model_config.hidden_size,
self.config.vocab_size,
self.config.hidden_size,
)
self.logits_processor = LogitsProcessor(self.model_config.vocab_size)
self.logits_processor = LogitsProcessor(self.config.vocab_size)
def forward(self,
input_ids: torch.Tensor,
@@ -265,354 +620,11 @@ class GptOssForCausalLM(nn.Module):
sampling_metadata)
return logits
def _load_weights_mxfp4(
self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
rename_mapping = {
"self_attn": "attn",
"input_layernorm.weight": "attn.norm.weight",
"post_attention_layernorm.weight": "mlp.norm.weight",
"embed_tokens": "embedding",
}
def maybe_rename(name: str) -> str:
for remap_name, new_name in rename_mapping.items():
if remap_name in name:
return name.replace(remap_name, new_name)
return name
params_dict = dict(self.named_parameters())
loaded_params: set[str] = set()
mxfp4_block = 32
tp_rank = get_tensor_model_parallel_rank()
tp_size = get_tensor_model_parallel_world_size()
intermediate_size = self.model_config.intermediate_size
intermediate_size_block = intermediate_size // mxfp4_block
per_rank_intermediate_size_block = cdiv(intermediate_size_block,
tp_size)
per_rank_intermediate_size = (per_rank_intermediate_size_block *
mxfp4_block)
# Calculate common slicing bounds for current rank
tp_rank_start = tp_rank * per_rank_intermediate_size
tp_rank_end = min((tp_rank + 1) * per_rank_intermediate_size,
intermediate_size)
# Attention heads per rank
heads_per_rank = self.model_config.num_attention_heads // tp_size
head_start = tp_rank * heads_per_rank
use_ep = self.vllm_config.parallel_config.enable_expert_parallel
ep_size = get_ep_group().world_size
ep_rank = get_ep_group().rank
num_experts = self.model_config.num_local_experts
experts_per_rank = num_experts // ep_size
ep_rank_start = ep_rank * experts_per_rank
ep_rank_end = (ep_rank + 1) * experts_per_rank
for name, weight in weights:
# FIXME(woosuk): Remove this after testing.
weight = weight.cuda()
if "gate_up_proj_blocks" in name:
# Handle MLP gate and up projection weights
new_name = name.replace("gate_up_proj_blocks", "w13_weight")
# flat weight from (E, 2 * N, block_size, entry_per_block)
# to (E, 2 * N, -1), shouldn't trigger copy for contiguous
weight = weight.view(num_experts, 2 * intermediate_size,
-1).contiguous()
# Extract gate and up projection parts
# since the weight is shuffled, we can slice directly
if use_ep:
narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
else:
narrow_weight = weight[:,
2 * tp_rank_start:2 * tp_rank_end,
...]
param = params_dict[new_name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param,
narrow_weight,
weight_name=new_name,
shard_id=None,
expert_id=None)
loaded_params.add(new_name)
elif "down_proj_blocks" in name:
# Handle MLP down projection weights
new_name = name.replace("down_proj_blocks", "w2_weight")
# same flatten here, but since 2 mx4 value are packed in 1
# uint8, divide by 2
weight = weight.view(num_experts, -1,
intermediate_size // 2).contiguous()
if use_ep:
narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
else:
narrow_weight = weight[...,
tp_rank_start // 2:tp_rank_end // 2]
param = params_dict[new_name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param,
narrow_weight,
weight_name=new_name,
shard_id=None,
expert_id=None)
loaded_params.add(new_name)
elif "gate_up_proj_scales" in name:
# Handle MLP gate and up projection weights scale
new_name = name.replace("gate_up_proj_scales",
"w13_weight_scale")
if use_ep:
narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
else:
narrow_weight = weight[:,
2 * tp_rank_start:2 * tp_rank_end,
...]
param = params_dict[new_name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param,
narrow_weight,
weight_name=new_name,
shard_id=None,
expert_id=None)
loaded_params.add(new_name)
elif "down_proj_scales" in name:
# Handle MLP down projection weights
new_name = name.replace("down_proj_scales", "w2_weight_scale")
if use_ep:
narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
else:
narrow_weight = weight[..., tp_rank_start //
mxfp4_block:tp_rank_end //
mxfp4_block]
param = params_dict[new_name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param,
narrow_weight,
weight_name=new_name,
shard_id=None,
expert_id=None)
loaded_params.add(new_name)
elif "gate_up_proj_bias" in name:
# Handle MLP gate and up projection biases
new_name = name.replace("gate_up_proj_bias", "w13_bias")
# Extract gate and up projection bias parts
if use_ep:
narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
else:
narrow_weight = weight[:,
2 * tp_rank_start:2 * tp_rank_end]
param = params_dict[new_name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param,
narrow_weight,
weight_name=new_name,
shard_id=None,
expert_id=None)
loaded_params.add(new_name)
elif "down_proj_bias" in name:
# Handle MLP down projection bias
new_name = name.replace("down_proj_bias", "w2_bias")
param = params_dict[new_name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
if use_ep:
weight = weight[ep_rank_start:ep_rank_end, ...]
else:
# (only load on rank 0 to avoid duplication)
if tp_rank != 0:
weight.zero_()
weight_loader(param,
weight,
weight_name=new_name,
shard_id=None,
expert_id=None)
loaded_params.add(new_name)
elif "sinks" in name:
# Handle attention sinks (distributed across ranks)
name = name.replace("self_attn", "attn")
param = params_dict[name]
narrow_weight = weight.narrow(0, head_start, heads_per_rank)
param.data.copy_(narrow_weight)
loaded_params.add(name)
elif "q_proj" in name or "k_proj" in name or "v_proj" in name:
shard_id = ("q" if "q_proj" in name else
"k" if "k_proj" in name else "v")
name = name.replace("self_attn", "attn")
param_name = name.replace(f"{shard_id}_proj", "qkv")
param = params_dict[param_name]
weight_loader = param.weight_loader
weight_loader(param, weight, loaded_shard_id=shard_id)
loaded_params.add(param_name)
else:
# Handle all other weights with potential renaming
renamed_name = maybe_rename(name)
if renamed_name not in params_dict:
continue
param = params_dict[renamed_name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, weight)
loaded_params.add(renamed_name)
return loaded_params
def _load_weights_other(
self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
rename_mapping = {
"self_attn": "attn",
"input_layernorm.weight": "attn.norm.weight",
"post_attention_layernorm.weight": "mlp.norm.weight",
"embed_tokens": "embedding",
}
def maybe_rename(name: str) -> str:
for remap_name, new_name in rename_mapping.items():
if remap_name in name:
return name.replace(remap_name, new_name)
return name
params_dict = dict(self.named_parameters())
loaded_params: set[str] = set()
tp_rank = get_tensor_model_parallel_rank()
tp_size = get_tensor_model_parallel_world_size()
intermediate_size = self.model_config.intermediate_size
per_rank_intermediate_size = cdiv(intermediate_size, tp_size)
# Calculate common slicing bounds for current rank
tp_rank_start = tp_rank * per_rank_intermediate_size
tp_rank_end = min((tp_rank + 1) * per_rank_intermediate_size,
intermediate_size)
# Attention heads per rank
heads_per_rank = self.model_config.num_attention_heads // tp_size
head_start = tp_rank * heads_per_rank
use_ep = self.vllm_config.parallel_config.enable_expert_parallel
ep_size = get_ep_group().world_size
ep_rank = get_ep_group().rank
num_experts = self.model_config.num_local_experts
experts_per_rank = num_experts // ep_size
ep_rank_start = ep_rank * experts_per_rank
ep_rank_end = (ep_rank + 1) * experts_per_rank
for name, weight in weights:
if ".experts.gate_up_proj" in name and "bias" not in name:
# Handle MLP gate and up projection weights
new_name = name.replace(".experts.gate_up_proj",
".experts.w13_weight")
# Extract gate and up projection parts
# since the weight is shuffled, we can slice directly
if use_ep:
narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
else:
narrow_weight = weight[:, :,
2 * tp_rank_start:2 * tp_rank_end]
narrow_weight = narrow_weight.permute(0, 2, 1).contiguous()
param = params_dict[new_name]
param.copy_(narrow_weight)
loaded_params.add(new_name)
elif ".experts.down_proj" in name and "bias" not in name:
# Handle MLP down projection weights
new_name = name.replace(".experts.down_proj",
".experts.w2_weight")
if use_ep:
narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
else:
narrow_weight = weight[:, tp_rank_start:tp_rank_end, :]
narrow_weight = narrow_weight.permute(0, 2, 1).contiguous()
param = params_dict[new_name]
param.copy_(narrow_weight)
loaded_params.add(new_name)
elif "gate_up_proj_bias" in name:
# Handle MLP gate and up projection biases
new_name = name.replace("gate_up_proj_bias", "w13_bias")
# Extract gate and up projection bias parts
if use_ep:
narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
else:
narrow_weight = weight[:,
2 * tp_rank_start:2 * tp_rank_end]
param = params_dict[new_name]
param.copy_(narrow_weight)
loaded_params.add(new_name)
elif "down_proj_bias" in name:
# Handle MLP down projection bias
new_name = name.replace("down_proj_bias", "w2_bias")
if use_ep:
weight = weight[ep_rank_start:ep_rank_end, ...]
else:
# (only load on rank 0 to avoid duplication)
if tp_rank != 0:
weight.zero_()
param = params_dict[new_name]
param.copy_(weight)
loaded_params.add(new_name)
elif "sinks" in name:
# Handle attention sinks (distributed across ranks)
name = name.replace("self_attn", "attn")
param = params_dict[name]
narrow_weight = weight.narrow(0, head_start, heads_per_rank)
param.data.copy_(narrow_weight)
loaded_params.add(name)
elif "q_proj" in name or "k_proj" in name or "v_proj" in name:
shard_id = ("q" if "q_proj" in name else
"k" if "k_proj" in name else "v")
name = name.replace("self_attn", "attn")
param_name = name.replace(f"{shard_id}_proj", "qkv")
param = params_dict[param_name]
weight_loader = param.weight_loader
weight_loader(param, weight, loaded_shard_id=shard_id)
loaded_params.add(param_name)
else:
# Handle all other weights with potential renaming
renamed_name = maybe_rename(name)
if renamed_name not in params_dict:
continue
param = params_dict[renamed_name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, weight)
loaded_params.add(renamed_name)
return loaded_params
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
quant_method = (self.model_config.quantization_config['quant_method']
if hasattr(self.model_config, "quantization_config")
else None)
if quant_method == "mxfp4":
return self._load_weights_mxfp4(weights)
else:
return self._load_weights_other(weights)
loader = AutoWeightsLoader(
self,
skip_prefixes=(["lm_head."]
if self.config.tie_word_embeddings else None),
)
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)