Upgrade to vllm 0.17.0 corex v4.1 overlay

Date: 2026-04-29 19:38:22 +08:00
Parent: 8fac6062e4
Commit: 938d0854a5
430 changed files with 35969 additions and 14511 deletions


@@ -6,9 +6,11 @@ import torch
 import torch
 import torch.nn as nn
 from transformers import PretrainedConfig
+from vllm.compilation.decorators import support_torch_compile
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
 from vllm.model_executor.layers.layernorm import GemmaRMSNorm
+from vllm.model_executor.layers.linear import ReplicatedLinear
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -18,6 +20,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.sequence import IntermediateTensors
 
+from .interfaces import SupportsPP
 from .step3p5 import Step3p5DecoderLayer, get_spec_layer_idx_from_weight_name
 from .utils import maybe_prefix
@@ -40,9 +43,11 @@ class SharedHead(nn.Module):
         return self.norm(hidden_states)
 
 
+@support_torch_compile
 class Step3p5AMultiTokenPredictorLayer(nn.Module):
     def __init__(
         self,
+        *,
         vllm_config: VllmConfig,
         prefix: str,
     ) -> None:
@@ -51,8 +56,15 @@ class Step3p5AMultiTokenPredictorLayer(nn.Module):
         quant_config = vllm_config.quant_config
         self.enorm = GemmaRMSNorm(config.hidden_size, config.rms_norm_eps)
         self.hnorm = GemmaRMSNorm(config.hidden_size, config.rms_norm_eps)
-        self.eh_proj = nn.Linear(config.hidden_size * 2, config.hidden_size, bias=False)
-        self.shared_head = SharedHead(config=config, quant_config=quant_config)
+        self.eh_proj = ReplicatedLinear(
+            config.hidden_size * 2,
+            config.hidden_size,
+            bias=False,
+            quant_config=quant_config,
+            return_bias=False,
+            prefix=f"{prefix}.eh_proj",
+        )
+        self.lm_head = SharedHead(config=config, quant_config=quant_config)
         self.mtp_block = Step3p5DecoderLayer(
             vllm_config,
             prefix=f"{prefix}.mtp_block",
@@ -64,9 +76,12 @@ class Step3p5AMultiTokenPredictorLayer(nn.Module):
         input_ids: torch.Tensor,
         positions: torch.Tensor,
         previous_hidden_states: torch.Tensor,
         inputs_embeds: torch.Tensor | None = None,
+        embed_tokens: VocabParallelEmbedding | None = None,
         spec_step_index: int = 0,
     ) -> torch.Tensor:
-        assert inputs_embeds is not None
+        if inputs_embeds is None:
+            assert embed_tokens is not None
+            inputs_embeds = embed_tokens(input_ids)
         inputs_embeds = self.enorm(inputs_embeds)
         previous_hidden_states = self.hnorm(previous_hidden_states)
@@ -74,7 +89,8 @@ class Step3p5AMultiTokenPredictorLayer(nn.Module):
         hidden_states = self.eh_proj(
             torch.cat([inputs_embeds, previous_hidden_states], dim=-1)
         )
-        hidden_states = self.mtp_block(positions=positions, hidden_states=hidden_states)
+        hidden_states, residual = self.mtp_block(positions=positions, hidden_states=hidden_states, residual=None)
+        hidden_states = hidden_states + residual
         return hidden_states
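
Note on the mtp_block change above: vLLM decoder layers conventionally return a (hidden_states, residual) pair so the next layer can fuse the pending residual add into its input norm; the consumer of the last layer's output has to apply that add itself, which is what the new hidden_states + residual line does. A hedged sketch of the convention, with layers standing in for any stack of such decoder layers:

# Sketch only, not part of the diff.
residual = None
for layer in layers:
    # each layer folds the incoming residual into its input norm and
    # returns the new pending residual alongside hidden_states
    hidden_states, residual = layer(
        positions=positions, hidden_states=hidden_states, residual=residual
    )
hidden_states = hidden_states + residual  # materialize the final deferred add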
@@ -92,8 +108,8 @@ class Step3p5AMultiTokenPredictor(nn.Module):
         self.layers = torch.nn.ModuleDict(
             {
                 str(idx): Step3p5AMultiTokenPredictorLayer(
-                    vllm_config,
-                    f"{prefix}.layers.{idx}",
+                    vllm_config=vllm_config,
+                    prefix=f"{prefix}.layers.{idx}",
                 )
                 for idx in range(
                     self.mtp_start_layer_idx,
@@ -112,14 +128,13 @@ class Step3p5AMultiTokenPredictor(nn.Module):
         inputs_embeds: torch.Tensor | None = None,
         spec_step_idx: int = 0,
     ) -> torch.Tensor:
-        if inputs_embeds is None:
-            inputs_embeds = self.embed_tokens(input_ids)
         current_step_idx = spec_step_idx % self.num_mtp_layers
         return self.layers[str(self.mtp_start_layer_idx + current_step_idx)](
             input_ids,
             positions,
             previous_hidden_states,
             inputs_embeds,
+            self.embed_tokens,
             current_step_idx,
         )
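
Note on the forward change above: instead of eagerly embedding input_ids in the predictor, the embedding table itself is passed down so the layer performs the lookup only when the caller did not supply inputs_embeds (e.g. when a caller has already computed embeddings). A sketch of the fallback order the layer now implements; resolve_inputs_embeds is a hypothetical helper, not vLLM API:

# Sketch only, not part of the diff.
def resolve_inputs_embeds(input_ids, inputs_embeds, embed_tokens):
    if inputs_embeds is not None:
        return inputs_embeds           # caller already embedded the tokens
    assert embed_tokens is not None    # the table is required for the fallback
    return embed_tokens(input_ids)     # embed lazily, inside the layer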
@@ -131,7 +146,7 @@ class Step3p5AMultiTokenPredictor(nn.Module):
         current_step_idx = spec_step_idx % self.num_mtp_layers
         mtp_layer = self.layers[str(self.mtp_start_layer_idx + current_step_idx)]
         logits = self.logits_processor(
-            mtp_layer.shared_head.head, mtp_layer.shared_head(hidden_states)
+            mtp_layer.lm_head.head, mtp_layer.lm_head(hidden_states)
         )
         return logits
@@ -139,7 +154,7 @@ class Step3p5AMultiTokenPredictor(nn.Module):
         return self.embed_tokens(input_ids)
 
 
-class Step3p5MTP(nn.Module):
+class Step3p5MTP(nn.Module, SupportsPP):
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
         self.config = vllm_config.model_config.hf_config
@@ -257,12 +272,15 @@ class Step3p5MTP(nn.Module):
             name = name.replace(".transformer.", ".")
             if "shared_head" in name:
                 name = name.replace("shared_head.output", "shared_head.head")
+                name = name.replace("shared_head", "lm_head")
             if "embed_tokens" in name:
                 assert (
                     hasattr(self.config, "num_nextn_predict_layers")
                     and self.config.num_nextn_predict_layers > 0
                 )
                 name = "model.embed_tokens.weight"
+            if name not in params_dict:
+                continue
             param = params_dict[name]
             weight_loader = getattr(
                 param, "weight_loader", default_weight_loader
@@ -280,6 +298,31 @@ class Step3p5MTP(nn.Module):
             and getattr(param, "requires_grad", False) is False
         }
         params_need_to_load -= optional_params
+        # Some attention quantization metadata may be absent in checkpoints
+        # depending on export path / runtime backend.
+        missing_params = params_need_to_load - loaded_params
+        skip_missing_suffixes = (
+            ".attn.k_scale",
+            ".attn.q_scale",
+            ".attn.v_scale",
+            ".attn.q_zero_point",
+            ".attn.k_zero_point",
+            ".attn.v_zero_point",
+        )
+        skipped_missing_params = sorted(
+            name for name in missing_params if name.endswith(skip_missing_suffixes)
+        )
+        if skipped_missing_params:
+            preview = tuple(skipped_missing_params[:10])
+            logger.warning_once(
+                "Step3p5MTP load_weights: skip %d missing optional attn quant "
+                "params with suffixes %s (showing first %d): %s",
+                len(skipped_missing_params),
+                skip_missing_suffixes,
+                len(preview),
+                preview,
+            )
+        params_need_to_load -= set(skipped_missing_params)
         if params_need_to_load != loaded_params:
             missing_params = list(params_need_to_load - loaded_params)
             param_name_example = missing_params[0]
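
A note on the suffix filter in the new load_weights block: str.endswith accepts a tuple of suffixes, so a single pass over missing_params suffices. A self-contained sketch of the same filtering logic, with invented parameter names:

# Sketch only; parameter names are made up for illustration.
skip_missing_suffixes = (".attn.k_scale", ".attn.q_scale", ".attn.v_scale")
missing_params = {
    "model.layers.0.self_attn.attn.k_scale",  # optional quant metadata -> skipped
    "model.layers.0.mlp.gate_proj.weight",    # real weight -> still a load error
}
skipped = sorted(n for n in missing_params if n.endswith(skip_missing_suffixes))
assert skipped == ["model.layers.0.self_attn.attn.k_scale"]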