Upgrade to vLLM 0.17.0 (corex v4.1 overlay)

2026-04-29 19:38:22 +08:00
parent 8fac6062e4
commit 938d0854a5
430 changed files with 35969 additions and 14511 deletions

@@ -32,6 +32,7 @@ from .deepseek_v2 import (
    DeepseekV2MoE,
    get_spec_layer_idx_from_weight_name,
)
from .interfaces import SupportsPP
from .utils import maybe_prefix
logger = init_logger(__name__)
@@ -180,7 +181,7 @@ class DeepSeekMultiTokenPredictor(nn.Module):
@support_torch_compile
-class DeepSeekMTP(nn.Module, DeepseekV2MixtureOfExperts):
+class DeepSeekMTP(nn.Module, SupportsPP, DeepseekV2MixtureOfExperts):
    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        self.config = vllm_config.model_config.hf_config
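
The new SupportsPP base advertises pipeline-parallel support to the engine. A minimal sketch of what that contract usually entails, assuming upstream vllm's make_empty_intermediate_tensors_factory helper (the overlay's actual wiring is outside this hunk):

    # Hypothetical sketch, not this overlay's code.
    import torch.nn as nn
    from vllm.model_executor.models.interfaces import SupportsPP
    from vllm.model_executor.models.utils import (
        make_empty_intermediate_tensors_factory)

    class MTPSketch(nn.Module, SupportsPP):
        def __init__(self, *, vllm_config, prefix: str = ""):
            super().__init__()
            hidden_size = vllm_config.model_config.hf_config.hidden_size
            # SupportsPP looks for this attribute: it builds placeholder
            # IntermediateTensors for pipeline stages that receive
            # hidden_states from the previous rank instead of running
            # the embedding layer.
            self.make_empty_intermediate_tensors = (
                make_empty_intermediate_tensors_factory(
                    ["hidden_states"], hidden_size))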
@@ -415,8 +416,141 @@ class DeepSeekMTP(nn.Module, DeepseekV2MixtureOfExperts):
                weight_loader(param, loaded_weight)
            if not is_fusion_moe_shared_experts_layer:
                loaded_params.add(name)
        # Validate that weights were loaded for each expected MTP layer.
        loaded_layers: set[int] = set()
        for param_name in loaded_params:
            spec_layer = get_spec_layer_idx_from_weight_name(
                self.config, param_name)
            if spec_layer is not None:
                loaded_layers.add(spec_layer)
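        # Illustration (assumed shapes, not from this diff): for DeepSeek-V3
        # with num_hidden_layers=61 and one MTP layer, the speculative
        # weights live under "model.layers.61.*", so loaded_layers == {61}.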
        for layer_idx in range(
                self.model.mtp_start_layer_idx,
                self.model.mtp_start_layer_idx + self.model.num_mtp_layers):
            if layer_idx not in loaded_layers:
                raise ValueError(
                    f"MTP speculative decoding layer {layer_idx} weights are "
                    "missing from the checkpoint. The checkpoint may have "
                    "been quantized without the MTP layers. Use a checkpoint "
                    "that includes the MTP layer weights, or disable "
                    "speculative decoding.")
        # Post-load optimization: fuse q_a_proj and kv_a_proj_with_mqa into a
        # single GEMM, then monkey-patch forward to forward_opt. Same logic
        # as DeepseekV2ForCausalLM.load_weights.
        opt_support_quant_method = [
            "GGUFLinearMethod",
            "UnquantizedLinearMethod",
            "CompressedTensorsW8A8Int8",
            "AWQMarlinLinearMethod",
        ]
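        # Why the fusion is sound: with weights stored as
        # [out_features, in_features], torch.cat([W_q, W_kv], dim=0) stacks
        # output rows, so a single GEMM computes both projections at once;
        # forward_opt is then expected to split the fused output back into
        # the q and kv parts.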

        def inject_layer(layer, quant_method, is_mla):
            logger.info(
                "DeepSeekMTP optimization: fusing q_a_proj and "
                "kv_a_proj_with_mqa for layer '%s' (quant_method=%s, "
                "is_mla=%s); replacing forward with forward_opt.",
                layer.__class__.__name__, quant_method, is_mla)
            q_lora_rank = getattr(layer, "q_lora_rank", None)
            if quant_method in ["UnquantizedLinearMethod",
                                "CompressedTensorsW8A8Int8"]:
                if q_lora_rank is not None:
                    layer.q_a_proj.weight.data = torch.cat(
                        [layer.q_a_proj.weight,
                         layer.kv_a_proj_with_mqa.weight], dim=0)
                    if hasattr(layer.q_a_proj, "weight_scale"):
                        layer.q_a_proj.weight_scale.data = torch.cat(
                            [layer.q_a_proj.weight_scale,
                             layer.kv_a_proj_with_mqa.weight_scale], dim=0)
                        del layer.kv_a_proj_with_mqa.weight_scale
                elif not is_mla:
                    layer.q_proj.weight.data = torch.cat(
                        [layer.q_proj.weight,
                         layer.kv_a_proj_with_mqa.weight], dim=0)
                    if hasattr(layer.q_proj, "weight_scale"):
                        layer.q_proj.weight_scale.data = torch.cat(
                            [layer.q_proj.weight_scale,
                             layer.kv_a_proj_with_mqa.weight_scale], dim=0)
                        del layer.kv_a_proj_with_mqa.weight_scale
                else:
                    return
                del layer.kv_a_proj_with_mqa.weight
                del layer.kv_a_proj_with_mqa
                if is_mla:
                    layer.mla_attn.forward = layer.mla_attn.forward_opt
                else:
                    layer.forward = layer.forward_opt
elif quant_method == "GGUFLinearMethod":
pass
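            # AWQ-Marlin stores the int32-packed qweight as
            # [in_features, out_features // pack_factor] (scales and qzeros
            # per quantization group), so outputs are stacked along dim=1
            # here, unlike the dim=0 concat for plain [out, in] weights.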
elif quant_method == "AWQMarlinLinearMethod":
dtype = layer.kv_a_proj_with_mqa.qweight.dtype
assert dtype == torch.int32
if q_lora_rank is not None:
layer.q_a_proj.qweight.data = torch.cat(
[layer.q_a_proj.qweight,
layer.kv_a_proj_with_mqa.qweight], dim=1)
layer.q_a_proj.scales.data = torch.cat(
[layer.q_a_proj.scales,
layer.kv_a_proj_with_mqa.scales], dim=1)
del layer.kv_a_proj_with_mqa.scales
layer.q_a_proj.qzeros.data = torch.cat(
[layer.q_a_proj.qzeros,
layer.kv_a_proj_with_mqa.qzeros], dim=1)
del layer.kv_a_proj_with_mqa.qzeros
elif not is_mla:
layer.q_proj.weight.data = torch.cat(
[layer.q_proj.weight,
layer.kv_a_proj_with_mqa.weight], dim=1)
layer.q_proj.scales.data = torch.cat(
[layer.q_proj.scales,
layer.kv_a_proj_with_mqa.scales], dim=1)
del layer.kv_a_proj_with_mqa.scales
layer.q_proj.qzeros.data = torch.cat(
[layer.q_proj.qzeros,
layer.kv_a_proj_with_mqa.qzeros], dim=1)
del layer.kv_a_proj_with_mqa.qzeros
else:
return
del layer.kv_a_proj_with_mqa.qweight
del layer.kv_a_proj_with_mqa
if is_mla:
layer.mla_attn.forward = layer.mla_attn.forward_opt
else:
layer.forward = layer.forward_opt
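
        # forward_opt is assumed to be provided by the patched DeepseekV2
        # attention classes: it runs the single fused GEMM and splits the
        # result back into the q and kv_a projections.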
        for _, layer in self.model.named_modules():
            if layer.__class__.__name__ in [
                    "DeepseekV2Attention", "DeepseekV2MLAAttention"]:
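                # CompressedTensors linear layers expose the concrete
                # per-layer scheme (e.g. CompressedTensorsW8A8Int8) via
                # `scheme`; other backends are identified by the
                # quant_method class name itself.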
                if hasattr(layer.kv_a_proj_with_mqa, "scheme"):
                    quant_method = (
                        layer.kv_a_proj_with_mqa.scheme.__class__.__name__)
                else:
                    quant_method = (layer.kv_a_proj_with_mqa
                                    .quant_method.__class__.__name__)
                if quant_method not in opt_support_quant_method:
                    # All attention layers share the model-wide quant
                    # config, so one unsupported layer means none can be
                    # fused.
                    break
                inject_layer(
                    layer, quant_method,
                    is_mla=(layer.__class__.__name__
                            == "DeepseekV2MLAAttention"))

        # Check that all parameters have been loaded.
        all_params = set(params_dict.keys())
        not_loaded = all_params - loaded_params
        if not_loaded:
            logger.warning(
                "DeepSeekMTP weight loading: %d parameters were NOT "
                "loaded:\n%s", len(not_loaded),
                "\n".join(sorted(not_loaded)))
        else:
            logger.info(
                "DeepSeekMTP weight loading: all %d parameters loaded "
                "successfully.", len(all_params))
        return loaded_params

    def _rewrite_spec_layer_name(self, spec_layer: int, name: str) -> str:
        """
        Rewrite the weight name to match the format of the original model.