Upgrade to vllm 0.17.0 corex v4.1 overlay
This commit is contained in:
@@ -6,9 +6,11 @@ import torch
|
||||
import torch.nn as nn
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.layernorm import GemmaRMSNorm
|
||||
from vllm.model_executor.layers.linear import ReplicatedLinear
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
@@ -18,6 +20,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
from .interfaces import SupportsPP
|
||||
from .step3p5 import Step3p5DecoderLayer, get_spec_layer_idx_from_weight_name
|
||||
from .utils import maybe_prefix
|
||||
|
||||
@@ -40,9 +43,11 @@ class SharedHead(nn.Module):
|
||||
return self.norm(hidden_states)
|
||||
|
||||
|
||||
@support_torch_compile
|
||||
class Step3p5AMultiTokenPredictorLayer(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
vllm_config: VllmConfig,
|
||||
prefix: str,
|
||||
) -> None:
|
||||
@@ -51,8 +56,15 @@ class Step3p5AMultiTokenPredictorLayer(nn.Module):
|
||||
quant_config = vllm_config.quant_config
|
||||
self.enorm = GemmaRMSNorm(config.hidden_size, config.rms_norm_eps)
|
||||
self.hnorm = GemmaRMSNorm(config.hidden_size, config.rms_norm_eps)
|
||||
self.eh_proj = nn.Linear(config.hidden_size * 2, config.hidden_size, bias=False)
|
||||
self.shared_head = SharedHead(config=config, quant_config=quant_config)
|
||||
self.eh_proj = ReplicatedLinear(
|
||||
config.hidden_size * 2,
|
||||
config.hidden_size,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
return_bias=False,
|
||||
prefix=f"{prefix}.eh_proj",
|
||||
)
|
||||
self.lm_head = SharedHead(config=config, quant_config=quant_config)
|
||||
self.mtp_block = Step3p5DecoderLayer(
|
||||
vllm_config,
|
||||
prefix=f"{prefix}.mtp_block",
|
||||
@@ -64,9 +76,12 @@ class Step3p5AMultiTokenPredictorLayer(nn.Module):
|
||||
positions: torch.Tensor,
|
||||
previous_hidden_states: torch.Tensor,
|
||||
inputs_embeds: torch.Tensor | None = None,
|
||||
embed_tokens: VocabParallelEmbedding | None = None,
|
||||
spec_step_index: int = 0,
|
||||
) -> torch.Tensor:
|
||||
assert inputs_embeds is not None
|
||||
if inputs_embeds is None:
|
||||
assert embed_tokens is not None
|
||||
inputs_embeds = embed_tokens(input_ids)
|
||||
inputs_embeds = self.enorm(inputs_embeds)
|
||||
previous_hidden_states = self.hnorm(previous_hidden_states)
|
||||
|
||||
@@ -74,7 +89,8 @@ class Step3p5AMultiTokenPredictorLayer(nn.Module):
|
||||
torch.cat([inputs_embeds, previous_hidden_states], dim=-1)
|
||||
)
|
||||
|
||||
hidden_states = self.mtp_block(positions=positions, hidden_states=hidden_states)
|
||||
hidden_states, residual = self.mtp_block(positions=positions, hidden_states=hidden_states, residual=None)
|
||||
hidden_states = hidden_states + residual
|
||||
return hidden_states
|
||||
|
||||
|
||||
@@ -92,8 +108,8 @@ class Step3p5AMultiTokenPredictor(nn.Module):
|
||||
self.layers = torch.nn.ModuleDict(
|
||||
{
|
||||
str(idx): Step3p5AMultiTokenPredictorLayer(
|
||||
vllm_config,
|
||||
f"{prefix}.layers.{idx}",
|
||||
vllm_config=vllm_config,
|
||||
prefix=f"{prefix}.layers.{idx}",
|
||||
)
|
||||
for idx in range(
|
||||
self.mtp_start_layer_idx,
|
||||
@@ -112,14 +128,13 @@ class Step3p5AMultiTokenPredictor(nn.Module):
|
||||
inputs_embeds: torch.Tensor | None = None,
|
||||
spec_step_idx: int = 0,
|
||||
) -> torch.Tensor:
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
current_step_idx = spec_step_idx % self.num_mtp_layers
|
||||
return self.layers[str(self.mtp_start_layer_idx + current_step_idx)](
|
||||
input_ids,
|
||||
positions,
|
||||
previous_hidden_states,
|
||||
inputs_embeds,
|
||||
self.embed_tokens,
|
||||
current_step_idx,
|
||||
)
|
||||
|
||||
@@ -131,7 +146,7 @@ class Step3p5AMultiTokenPredictor(nn.Module):
|
||||
current_step_idx = spec_step_idx % self.num_mtp_layers
|
||||
mtp_layer = self.layers[str(self.mtp_start_layer_idx + current_step_idx)]
|
||||
logits = self.logits_processor(
|
||||
mtp_layer.shared_head.head, mtp_layer.shared_head(hidden_states)
|
||||
mtp_layer.lm_head.head, mtp_layer.lm_head(hidden_states)
|
||||
)
|
||||
return logits
|
||||
|
||||
@@ -139,7 +154,7 @@ class Step3p5AMultiTokenPredictor(nn.Module):
|
||||
return self.embed_tokens(input_ids)
|
||||
|
||||
|
||||
class Step3p5MTP(nn.Module):
|
||||
class Step3p5MTP(nn.Module, SupportsPP):
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
self.config = vllm_config.model_config.hf_config
|
||||
@@ -257,12 +272,15 @@ class Step3p5MTP(nn.Module):
|
||||
name = name.replace(".transformer.", ".")
|
||||
if "shared_head" in name:
|
||||
name = name.replace("shared_head.output", "shared_head.head")
|
||||
name = name.replace("shared_head", "lm_head")
|
||||
if "embed_tokens" in name:
|
||||
assert (
|
||||
hasattr(self.config, "num_nextn_predict_layers")
|
||||
and self.config.num_nextn_predict_layers > 0
|
||||
)
|
||||
name = "model.embed_tokens.weight"
|
||||
if name not in params_dict:
|
||||
continue
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(
|
||||
param, "weight_loader", default_weight_loader
|
||||
@@ -280,6 +298,31 @@ class Step3p5MTP(nn.Module):
|
||||
and getattr(param, "requires_grad", False) is False
|
||||
}
|
||||
params_need_to_load -= optional_params
|
||||
# Some attention quantization metadata may be absent in checkpoints
|
||||
# depending on export path / runtime backend.
|
||||
missing_params = params_need_to_load - loaded_params
|
||||
skip_missing_suffixes = (
|
||||
".attn.k_scale",
|
||||
".attn.q_scale",
|
||||
".attn.v_scale",
|
||||
".attn.q_zero_point",
|
||||
".attn.k_zero_point",
|
||||
".attn.v_zero_point",
|
||||
)
|
||||
skipped_missing_params = sorted(
|
||||
name for name in missing_params if name.endswith(skip_missing_suffixes)
|
||||
)
|
||||
if skipped_missing_params:
|
||||
preview = tuple(skipped_missing_params[:10])
|
||||
logger.warning_once(
|
||||
"Step3p5MTP load_weights: skip %d missing optional attn quant "
|
||||
"params with suffixes %s (showing first %d): %s",
|
||||
len(skipped_missing_params),
|
||||
skip_missing_suffixes,
|
||||
len(preview),
|
||||
preview,
|
||||
)
|
||||
params_need_to_load -= set(skipped_missing_params)
|
||||
if params_need_to_load != loaded_params:
|
||||
missing_params = list(params_need_to_load - loaded_params)
|
||||
param_name_example = missing_params[0]
|
||||
|
||||
Reference in New Issue
Block a user