Upgrade to vllm 0.17.0 corex v4.1 overlay
@@ -32,6 +32,7 @@ from .deepseek_v2 import (
    DeepseekV2MoE,
    get_spec_layer_idx_from_weight_name,
)
from .interfaces import SupportsPP
from .utils import maybe_prefix

logger = init_logger(__name__)

@@ -180,7 +181,7 @@ class DeepSeekMultiTokenPredictor(nn.Module):


@support_torch_compile
-class DeepSeekMTP(nn.Module, DeepseekV2MixtureOfExperts):
+class DeepSeekMTP(nn.Module, SupportsPP, DeepseekV2MixtureOfExperts):
    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        self.config = vllm_config.model_config.hf_config

@@ -415,8 +416,141 @@ class DeepSeekMTP(nn.Module, DeepseekV2MixtureOfExperts):
                weight_loader(param, loaded_weight)
                if not is_fusion_moe_shared_experts_layer:
                    loaded_params.add(name)

        # Validate that weights were loaded for each expected MTP layer.
        loaded_layers: set[int] = set()
        for param_name in loaded_params:
            spec_layer = get_spec_layer_idx_from_weight_name(self.config, param_name)
            if spec_layer is not None:
                loaded_layers.add(spec_layer)
        for layer_idx in range(
                self.model.mtp_start_layer_idx,
                self.model.mtp_start_layer_idx + self.model.num_mtp_layers,
        ):
            if layer_idx not in loaded_layers:
                raise ValueError(
                    f"MTP speculative decoding layer {layer_idx} weights "
                    "missing from checkpoint. The checkpoint may have "
                    "been quantized without including the MTP layers. "
                    "Use a checkpoint that includes MTP layer weights, "
                    "or disable speculative decoding."
                )
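        # Note: this check assumes get_spec_layer_idx_from_weight_name returns
        # the MTP layer index encoded in a weight name (and None for weights
        # that belong to the base model), so every index in
        # [mtp_start_layer_idx, mtp_start_layer_idx + num_mtp_layers) must
        # appear at least once among the loaded parameter names.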

        # Post-load optimization: fuse q_a_proj and kv_a_proj_with_mqa
        # into a single GEMM, then monkey-patch forward to forward_opt.
        # Same logic as DeepseekV2ForCausalLM.load_weights.
        opt_support_quant_method = [
            "GGUFLinearMethod", "UnquantizedLinearMethod",
            "CompressedTensorsW8A8Int8", "AWQMarlinLinearMethod",
        ]

        def inject_layer(layer, quant_method, is_mla):
            logger.info(
                "DeepSeekMTP optimization: fused q_a_proj and "
                "kv_a_proj_with_mqa for layer '%s' (quant_method=%s, "
                "is_mla=%s). Forward replaced with forward_opt.",
                layer.__class__.__name__, quant_method, is_mla)
            q_lora_rank = getattr(layer, "q_lora_rank", None)
            if quant_method in ["UnquantizedLinearMethod",
                                "CompressedTensorsW8A8Int8"]:
                if q_lora_rank is not None:
                    layer.q_a_proj.weight.data = torch.cat(
                        [layer.q_a_proj.weight,
                         layer.kv_a_proj_with_mqa.weight], dim=0)
                    if hasattr(layer.q_a_proj, "weight_scale"):
                        layer.q_a_proj.weight_scale.data = torch.cat(
                            [layer.q_a_proj.weight_scale,
                             layer.kv_a_proj_with_mqa.weight_scale], dim=0)
                        del layer.kv_a_proj_with_mqa.weight_scale
                elif not is_mla:
                    layer.q_proj.weight.data = torch.cat(
                        [layer.q_proj.weight,
                         layer.kv_a_proj_with_mqa.weight], dim=0)
                    if hasattr(layer.q_proj, "weight_scale"):
                        layer.q_proj.weight_scale.data = torch.cat(
                            [layer.q_proj.weight_scale,
                             layer.kv_a_proj_with_mqa.weight_scale], dim=0)
                        del layer.kv_a_proj_with_mqa.weight_scale
                else:
                    return
                del layer.kv_a_proj_with_mqa.weight
                del layer.kv_a_proj_with_mqa
                if is_mla:
                    layer.mla_attn.forward = layer.mla_attn.forward_opt
                else:
                    layer.forward = layer.forward_opt
            elif quant_method == "GGUFLinearMethod":
                pass
            elif quant_method == "AWQMarlinLinearMethod":
                dtype = layer.kv_a_proj_with_mqa.qweight.dtype
                assert dtype == torch.int32
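                # The AWQ-style packed layout keeps the input dimension first
                # (qweight is [in_features, out_features // pack_factor]), so
                # the output dimension to concatenate along here is dim=1,
                # unlike the dense branch above, which concatenates rows
                # (dim=0) of [out_features, in_features] weights.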
                if q_lora_rank is not None:
                    layer.q_a_proj.qweight.data = torch.cat(
                        [layer.q_a_proj.qweight,
                         layer.kv_a_proj_with_mqa.qweight], dim=1)
                    layer.q_a_proj.scales.data = torch.cat(
                        [layer.q_a_proj.scales,
                         layer.kv_a_proj_with_mqa.scales], dim=1)
                    del layer.kv_a_proj_with_mqa.scales
                    layer.q_a_proj.qzeros.data = torch.cat(
                        [layer.q_a_proj.qzeros,
                         layer.kv_a_proj_with_mqa.qzeros], dim=1)
                    del layer.kv_a_proj_with_mqa.qzeros
                elif not is_mla:
                    layer.q_proj.weight.data = torch.cat(
                        [layer.q_proj.weight,
                         layer.kv_a_proj_with_mqa.weight], dim=1)
                    layer.q_proj.scales.data = torch.cat(
                        [layer.q_proj.scales,
                         layer.kv_a_proj_with_mqa.scales], dim=1)
                    del layer.kv_a_proj_with_mqa.scales
                    layer.q_proj.qzeros.data = torch.cat(
                        [layer.q_proj.qzeros,
                         layer.kv_a_proj_with_mqa.qzeros], dim=1)
                    del layer.kv_a_proj_with_mqa.qzeros
                else:
                    return
                del layer.kv_a_proj_with_mqa.qweight
                del layer.kv_a_proj_with_mqa
                if is_mla:
                    layer.mla_attn.forward = layer.mla_attn.forward_opt
                else:
                    layer.forward = layer.forward_opt

        for _, layer in self.model.named_modules():
            if layer.__class__.__name__ in [
                    "DeepseekV2Attention", "DeepseekV2MLAAttention"
            ]:
                if hasattr(layer.kv_a_proj_with_mqa, "scheme"):
                    quant_method = (
                        layer.kv_a_proj_with_mqa.scheme.__class__.__name__)
                else:
                    quant_method = (
                        layer.kv_a_proj_with_mqa
                        .quant_method.__class__.__name__)
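                # A plain `break` below (rather than `continue`) assumes every
                # attention layer uses the same quantization scheme: one
                # unsupported layer disables the fusion for the whole model.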
                if quant_method not in opt_support_quant_method:
                    break
                inject_layer(
                    layer, quant_method,
                    is_mla=(layer.__class__.__name__
                            == "DeepseekV2MLAAttention"))

        # Check if all parameters have been loaded
        all_params = set(params_dict.keys())
        not_loaded = all_params - loaded_params
        if not_loaded:
            logger.warning(
                "DeepSeekMTP weight loading: %d parameters were NOT loaded.\n%s",
                len(not_loaded),
                "\n".join(sorted(not_loaded)),
            )
        else:
            logger.info(
                "DeepSeekMTP weight loading: All %d parameters loaded "
                "successfully.",
                len(all_params),
            )

        return loaded_params

    def _rewrite_spec_layer_name(self, spec_layer: int, name: str) -> str:
        """
        Rewrite the weight name to match the format of the original model.
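Beyond the diff itself, the core of the inject_layer optimization is that stacking q_a_proj and kv_a_proj_with_mqa along their output dimension lets a single GEMM produce both projections; the monkey-patched forward_opt is then expected to split the fused output back into the two pieces. Below is a minimal standalone sketch of that equivalence using plain torch.nn.Linear layers; the sizes and the split step are illustrative assumptions, not code from this commit.

import torch

# Illustrative sizes only: hidden size, q LoRA rank, and the combined
# kv_lora_rank + qk_rope_head_dim width produced by kv_a_proj_with_mqa.
hidden_size, q_lora_rank, kv_a_width = 1024, 256, 192

q_a_proj = torch.nn.Linear(hidden_size, q_lora_rank, bias=False)
kv_a_proj_with_mqa = torch.nn.Linear(hidden_size, kv_a_width, bias=False)

# Fuse along the output dimension (dim=0, because nn.Linear stores weight as
# [out_features, in_features]), mirroring the UnquantizedLinearMethod branch
# of inject_layer.
fused = torch.nn.Linear(hidden_size, q_lora_rank + kv_a_width, bias=False)
fused.weight.data = torch.cat(
    [q_a_proj.weight, kv_a_proj_with_mqa.weight], dim=0)

x = torch.randn(4, hidden_size)

# One GEMM, then split the output columns back into the two projections.
q_fused, kv_fused = fused(x).split([q_lora_rank, kv_a_width], dim=-1)

assert torch.allclose(q_fused, q_a_proj(x), atol=1e-6)
assert torch.allclose(kv_fused, kv_a_proj_with_mqa(x), atol=1e-6)

The AWQMarlin branch performs the same stacking, but along dim=1, since the packed qweight/scales/qzeros tensors store the input dimension first.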