Upgrade to vllm 0.17.0 corex v4.1 overlay

2026-04-29 19:38:22 +08:00
parent 8fac6062e4
commit 938d0854a5
430 changed files with 35969 additions and 14511 deletions

View File

@@ -35,6 +35,7 @@ from vllm.model_executor.layers.quantization.quark.utils import (
)
from vllm.model_executor.models.utils import WeightsMapper
from vllm.platforms import current_platform
from vllm.transformers_utils.config import get_config
if TYPE_CHECKING:
from vllm.model_executor.models.utils import WeightsMapper
@@ -59,6 +60,22 @@ class QuarkConfig(QuantizationConfig):
self.kv_cache_group = kv_cache_group
self.kv_cache_config = kv_cache_config
self.pack_method = pack_method
self.dynamic_mxfp4_quant = False
def maybe_update_config(self, model_name: str, revision: str | None = None):
self.hf_config = get_config(
model=model_name,
trust_remote_code=False, # or get from model_config if available
revision=revision,
config_format="auto",
)
quant_config = getattr(self.hf_config, "quantization_config", None)
if quant_config is not None:
quant_dtype = quant_config["global_quant_config"]["weight"]["dtype"]
model_type = self.hf_config.model_type
if quant_dtype == "fp4" and model_type == "deepseek_v3":
self.dynamic_mxfp4_quant = True
def get_linear_method(self) -> "QuarkLinearMethod":
return QuarkLinearMethod(self)
@@ -108,7 +125,20 @@ class QuarkConfig(QuantizationConfig):
if should_ignore_layer(
prefix, ignore=exclude_layers, fused_mapping=self.packed_modules_mapping
):
return UnquantizedLinearMethod()
if (
"self_attn" not in prefix # only quantize attention projections
or not getattr(self, "dynamic_mxfp4_quant", False)
or not isinstance(layer, LinearBase)  # only linear layers take this path
):
return UnquantizedLinearMethod()
scheme = self.get_scheme(
layer=layer,
layer_name=prefix,
dynamic_mxfp4_quant=True,
)
layer.scheme = scheme
return QuarkLinearMethod(self)
if isinstance(layer, LinearBase):
scheme = self.get_scheme(layer=layer, layer_name=prefix)
layer.scheme = scheme
@@ -450,7 +480,9 @@ class QuarkConfig(QuantizationConfig):
)
return global_quant_config
def _get_scheme_from_config(self, config: dict[str, Any]) -> "QuarkScheme":
def _get_scheme_from_config(
self, config: dict[str, Any], dynamic_mxfp4_quant: bool = False
) -> "QuarkScheme":
if config.get("output_tensors") or config.get("bias"):
raise NotImplementedError(
"Currently, Quark models with output_tensors "
@@ -473,7 +505,9 @@ class QuarkConfig(QuantizationConfig):
input_symmetric=input_config.get("symmetric"),
)
elif self._is_w_ocp_mx_a_x(weight_config, input_config):
return QuarkOCP_MX(weight_config, input_config)
return QuarkOCP_MX(
weight_config, input_config, dynamic_mxfp4_quant=dynamic_mxfp4_quant
)
raise NotImplementedError(
"No quark compatible scheme was found. "
@@ -481,11 +515,15 @@ class QuarkConfig(QuantizationConfig):
f"Input config: {input_config}"
)
def get_scheme(self, layer: torch.nn.Module, layer_name: str) -> "QuarkScheme":
def get_scheme(
self, layer: torch.nn.Module, layer_name: str, dynamic_mxfp4_quant: bool = False
) -> "QuarkScheme":
layer_quant_config = self._find_matched_config(layer_name, layer)
# Find the quant_scheme
scheme = self._get_scheme_from_config(layer_quant_config)
scheme = self._get_scheme_from_config(
layer_quant_config, dynamic_mxfp4_quant=dynamic_mxfp4_quant
)
# Raise error if device does not support the scheme
# (e.g. fp8 needs ada lovelace)
self._check_scheme_supported(scheme.get_min_capability())
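As a standalone illustration of the rule that maybe_update_config introduces above (read the checkpoint's HF quantization_config and enable dynamic MXFP4 only for DeepSeek-V3 checkpoints whose global weight dtype is "fp4"), here is a minimal sketch; FakeHFConfig and wants_dynamic_mxfp4 are hypothetical stand-ins for illustration, not vLLM APIs.

from dataclasses import dataclass
from typing import Any

@dataclass
class FakeHFConfig:
    # Stand-in for the HF config object returned by get_config().
    model_type: str
    quantization_config: dict[str, Any] | None = None

def wants_dynamic_mxfp4(hf_config: FakeHFConfig) -> bool:
    # Mirrors QuarkConfig.maybe_update_config: only DeepSeek-V3 checkpoints
    # with a global fp4 weight dtype take the dynamic MXFP4 path.
    quant_config = hf_config.quantization_config
    if quant_config is None:
        return False
    quant_dtype = quant_config["global_quant_config"]["weight"]["dtype"]
    return quant_dtype == "fp4" and hf_config.model_type == "deepseek_v3"

if __name__ == "__main__":
    cfg = FakeHFConfig(
        model_type="deepseek_v3",
        quantization_config={"global_quant_config": {"weight": {"dtype": "fp4"}}},
    )
    print(wants_dynamic_mxfp4(cfg))  # True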

View File

@@ -5,8 +5,8 @@ from typing import Any
import torch
import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm import envs
from vllm._aiter_ops import rocm_aiter_ops
from vllm.config import get_current_vllm_config
from vllm.logger import init_logger
@@ -32,6 +32,7 @@ from vllm.model_executor.layers.quantization.mxfp4 import (
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
prepare_fp8_moe_layer_for_marlin,
)
from vllm.model_executor.layers.quantization.utils.mxfp4_utils import _swizzle_mxfp4
from vllm.model_executor.layers.quantization.utils.ocp_mx_utils import (
OCP_MX_BLOCK_SIZE,
OCP_MX_Scheme,
@@ -49,7 +50,11 @@ from vllm.utils.math_utils import round_up
logger = init_logger(__name__)
__all__ = ["QuarkMoEMethod", "QuarkW8A8Fp8MoEMethod", "QuarkOCP_MX_MoEMethod"]
__all__ = [
"QuarkMoEMethod",
"QuarkOCP_MX_MoEMethod",
"QuarkOCP_MX_MoEMethod_OSS",
]
class QuarkMoEMethod(FusedMoEMethodBase):
@@ -71,14 +76,30 @@ class QuarkMoEMethod(FusedMoEMethodBase):
"output_tensors and bias "
"quantized are not supported"
)
weight_config = layer_quant_config.get("weight")
input_config = layer_quant_config.get("input_tensors")
if quant_config._is_fp8_w4a8(weight_config, input_config):
return QuarkW4A8Fp8MoEMethod(weight_config, input_config, module.moe_config)
elif quant_config._is_fp8_w8a8(weight_config, input_config):
return QuarkW8A8Fp8MoEMethod(weight_config, input_config, module.moe_config)
elif quant_config._is_w_ocp_mx_a_x(weight_config, input_config):
return QuarkOCP_MX_MoEMethod(weight_config, input_config, module.moe_config)
emulate = not current_platform.supports_mx() or not (
rocm_aiter_ops.is_fused_moe_enabled()
)
if (
input_config.get("dtype") == "fp8_e4m3"
and not input_config.get("is_dynamic")
and not emulate
):
return QuarkOCP_MX_MoEMethod_OSS(
weight_config, input_config, module.moe_config
)
else:
return QuarkOCP_MX_MoEMethod(
weight_config, input_config, module.moe_config
)
else:
raise RuntimeError("Unsupported FusedMoe scheme")
@@ -706,13 +727,11 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
get_current_vllm_config().model_config.hf_config, "model_type", None
)
self._emulate = (
self.emulate = (
not current_platform.supports_mx()
or not self.ocp_mx_scheme.startswith("w_mxfp4")
) and (self.mxfp4_backend is None or not self.use_rocm_aiter_moe)
self.emulate = True if self.model_type == "gpt_oss" else self._emulate
if self.emulate:
logger.warning_once(
f"The current mode (supports_mx={current_platform.supports_mx()}, "
@@ -753,6 +772,7 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
)
params_dtype = torch.uint8
self.intermediate_size_per_partition = intermediate_size_per_partition
if self.model_type == "gpt_oss":
if current_platform.is_rocm():
intermediate_size_per_partition_after_pad = round_up(
@@ -765,6 +785,10 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
else:
intermediate_size_per_partition_after_pad = intermediate_size_per_partition
self.unpadded_hidden_size = extra_weight_attrs.get(
"unpadded_hidden_size", hidden_size
)
# WEIGHTS
w13_weight = torch.nn.Parameter(
torch.empty(
@@ -991,30 +1015,20 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
if not self.emulate:
if (
self.model_type == "gpt_oss"
and self.mxfp4_backend == Mxfp4Backend.TRITON
):
raise NotImplementedError(
"Triton kernel implemented fused MoE for GPT_OSS model "
"in Quark(MoE) format is not integrated or provided yet."
)
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
rocm_aiter_fused_experts,
)
else:
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
rocm_aiter_fused_experts,
)
return rocm_aiter_fused_experts(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
activation=layer.activation,
quant_config=self.moe_quant_config,
expert_map=layer.expert_map,
)
return rocm_aiter_fused_experts(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
activation=layer.activation,
quant_config=self.moe_quant_config,
expert_map=layer.expert_map,
)
else:
from vllm.model_executor.layers.fused_moe import fused_experts
@@ -1031,3 +1045,133 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
expert_map=layer.expert_map,
quant_config=self.moe_quant_config,
)
class QuarkOCP_MX_MoEMethod_OSS(QuarkOCP_MX_MoEMethod):
def __init__(
self,
weight_config: dict[str, Any],
input_config: dict[str, Any],
moe: FusedMoEConfig,
):
super().__init__(weight_config, input_config, moe)
def process_weights_after_loading(self, layer):
from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig
w13_bias = layer.w13_bias.to(torch.float32)
w2_bias = layer.w2_bias.to(torch.float32)
layer.w13_bias = torch.nn.Parameter(w13_bias, requires_grad=False)
layer.w2_bias = torch.nn.Parameter(w2_bias, requires_grad=False)
# FIXME: num_warps needs to be adjusted based on batch size
# (this only applies in batched mode)
if self.moe.use_ep:
num_warps = 4 if envs.VLLM_MOE_DP_CHUNK_SIZE <= 512 else 8
else:
num_warps = 8
w13_weight, w13_flex, w13_scale = _swizzle_mxfp4(
layer.w13_weight, layer.w13_weight_scale, num_warps
)
w2_weight, w2_flex, w2_scale = _swizzle_mxfp4(
layer.w2_weight, layer.w2_weight_scale, num_warps
)
self.w13_weight_triton_tensor = w13_weight
self.w2_weight_triton_tensor = w2_weight
# delete the original weights to save memory on a single GPU
del layer.w13_weight
del layer.w2_weight
layer.w13_weight = None
layer.w2_weight = None
torch.cuda.empty_cache()
if self.static_input_scales:
if layer.w13_input_scale is None or layer.w2_input_scale is None:
raise ValueError(
"QuantConfig has static quantization, but found "
"activation scales are None."
)
if not all_close_1d(layer.w13_input_scale) or not all_close_1d(
layer.w2_input_scale
):
logger.warning_once(
"Found input_scales that are not equal for "
"fp8 MoE layer. Using the maximum across experts "
"for each layer."
)
layer.w13_input_scale = torch.nn.Parameter(
layer.w13_input_scale.max().to(torch.float32), requires_grad=False
)
layer.w2_input_scale = torch.nn.Parameter(
layer.w2_input_scale.max().to(torch.float32), requires_grad=False
)
from triton_kernels.numerics import InFlexData
lhs_data13 = InFlexData(scale=layer.w13_input_scale)
lhs_data2 = InFlexData(scale=layer.w2_input_scale)
self.w13_precision_config = PrecisionConfig(
weight_scale=w13_scale,
flex_ctx=FlexCtx(rhs_data=w13_flex, lhs_data=lhs_data13),
)
self.w2_precision_config = PrecisionConfig(
weight_scale=w2_scale,
flex_ctx=FlexCtx(rhs_data=w2_flex, lhs_data=lhs_data2),
)
def get_fused_moe_quant_config(
self, layer: torch.nn.Module
) -> FusedMoEQuantConfig | None:
return mxfp4_w4a8_moe_quant_config(
w1_scale=self.w13_precision_config,
w2_scale=self.w2_precision_config,
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
w1_bias=layer.w13_bias,
w2_bias=layer.w2_bias,
block_shape=None,
)
@property
def is_monolithic(self) -> bool:
return True
def apply_monolithic(
self,
layer: torch.nn.Module,
x: torch.Tensor,
router_logits: torch.Tensor,
expert_map: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
if layer.enable_eplb:
raise NotImplementedError(
"EPLB not supported for `QuarkW4MXFp4MoEMethod_OSS` yet."
)
from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import ( # noqa: E501
triton_kernel_moe_forward,
)
return triton_kernel_moe_forward(
hidden_states=x,
w1=self.w13_weight_triton_tensor,
w2=self.w2_weight_triton_tensor,
gating_output=router_logits,
topk=layer.top_k,
renormalize=layer.renormalize,
global_num_experts=layer.global_num_experts,
expert_map=expert_map,
quant_config=self.moe_quant_config,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
unpadded_N_w1=self.intermediate_size_per_partition * 2,
unpadded_K_w1=self.unpadded_hidden_size,
unpadded_N_w2=self.unpadded_hidden_size,
unpadded_K_w2=self.intermediate_size_per_partition,
)
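To make the new MoE dispatch easier to follow, here is a standalone sketch of the selection rule added in QuarkMoEMethod.get_moe_method; the supports_mx and aiter_fused_moe_enabled booleans stand in for current_platform.supports_mx() and rocm_aiter_ops.is_fused_moe_enabled(), and pick_moe_method itself is illustrative only.

from typing import Any

def pick_moe_method(
    input_config: dict[str, Any],
    supports_mx: bool,
    aiter_fused_moe_enabled: bool,
) -> str:
    # Emulation is forced when the platform lacks MX support or the AITER
    # fused-MoE path is disabled.
    emulate = not supports_mx or not aiter_fused_moe_enabled
    # Static fp8_e4m3 activations on a non-emulated platform take the
    # triton_kernels ("OSS") method; everything else keeps the existing one.
    if (
        input_config.get("dtype") == "fp8_e4m3"
        and not input_config.get("is_dynamic")
        and not emulate
    ):
        return "QuarkOCP_MX_MoEMethod_OSS"
    return "QuarkOCP_MX_MoEMethod"

if __name__ == "__main__":
    static_fp8 = {"dtype": "fp8_e4m3", "is_dynamic": False}
    print(pick_moe_method(static_fp8, supports_mx=True, aiter_fused_moe_enabled=True))
    print(pick_moe_method(static_fp8, supports_mx=False, aiter_fused_moe_enabled=True))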

View File

@@ -24,7 +24,12 @@ from vllm.model_executor.layers.quantization.utils.ocp_mx_utils import (
OCP_MX_BLOCK_SIZE,
OCP_MX_Scheme,
)
from vllm.model_executor.parameter import GroupQuantScaleParameter, PackedvLLMParameter
from vllm.model_executor.parameter import (
GroupQuantScaleParameter,
ModelWeightParameter,
PackedvLLMParameter,
)
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from .quark_scheme import QuarkScheme
@@ -169,13 +174,16 @@ except (ImportError, AttributeError, RuntimeError):
class QuarkOCP_MX(QuarkScheme):
def __init__(
self, weight_quant_spec: dict[str, Any], input_quant_spec: dict[str, Any]
self,
weight_quant_spec: dict[str, Any],
input_quant_spec: dict[str, Any],
dynamic_mxfp4_quant: bool = False,
):
self.out_dtype = torch.get_default_dtype()
self.qscheme = "per_group"
self.weight_quant_spec = weight_quant_spec
self.input_quant_spec = input_quant_spec
self.dynamic_mxfp4_quant = dynamic_mxfp4_quant
self.weight_dtype = weight_quant_spec["dtype"].replace("fp", "mxfp")
self.input_dtype = input_quant_spec["dtype"].replace("fp", "mxfp")
@@ -269,7 +277,13 @@ class QuarkOCP_MX(QuarkScheme):
layer.weight_scale.data, requires_grad=False
)
else:
if self.rocm_use_aiter_fp4_asm_gemm:
if self.dynamic_mxfp4_quant:
w_q, w_s = dynamic_mxfp4_quant(layer.weight)
layer.weight_scale = torch.nn.Parameter(
w_s.T.contiguous(), requires_grad=False
)
layer.weight = torch.nn.Parameter(w_q, requires_grad=False)
elif self.rocm_use_aiter_fp4_asm_gemm:
# shuffle weight scale
weight_scale_shuffle = layer.weight_scale.data
sm, sn = weight_scale_shuffle.shape
@@ -302,36 +316,51 @@ class QuarkOCP_MX(QuarkScheme):
weight_loader: Callable,
**kwargs,
):
output_size_per_partition = sum(output_partition_sizes)
layer.logical_widths = output_partition_sizes
if self.dynamic_mxfp4_quant:
weight = ModelWeightParameter(
data=torch.empty(
sum(output_partition_sizes),
input_size_per_partition,
dtype=params_dtype,
),
input_dim=1,
output_dim=0,
weight_loader=weight_loader,
)
# WEIGHT
weight = PackedvLLMParameter(
data=torch.empty(
output_size_per_partition,
self.get_packed_dim(input_size_per_partition, self.weight_dtype),
dtype=torch.uint8,
),
input_dim=1,
output_dim=0,
packed_dim=1,
packed_factor=self.packed_factor,
weight_loader=weight_loader,
)
layer.register_parameter("weight", weight)
layer.register_parameter("weight", weight)
set_weight_attrs(weight, kwargs)
else:
output_size_per_partition = sum(output_partition_sizes)
layer.logical_widths = output_partition_sizes
# WEIGHT SCALE
weight_scale = GroupQuantScaleParameter(
data=torch.empty(
output_size_per_partition,
input_size_per_partition // OCP_MX_BLOCK_SIZE,
dtype=torch.uint8,
),
input_dim=1,
output_dim=0,
weight_loader=weight_loader,
)
layer.register_parameter("weight_scale", weight_scale)
# WEIGHT
weight = PackedvLLMParameter(
data=torch.empty(
output_size_per_partition,
self.get_packed_dim(input_size_per_partition, self.weight_dtype),
dtype=torch.uint8,
),
input_dim=1,
output_dim=0,
packed_dim=1,
packed_factor=self.packed_factor,
weight_loader=weight_loader,
)
layer.register_parameter("weight", weight)
# WEIGHT SCALE
weight_scale = GroupQuantScaleParameter(
data=torch.empty(
output_size_per_partition,
input_size_per_partition // OCP_MX_BLOCK_SIZE,
dtype=torch.uint8,
),
input_dim=1,
output_dim=0,
weight_loader=weight_loader,
)
layer.register_parameter("weight_scale", weight_scale)
def apply_weights(
self,