### What this PR does / why we need it?
When the GLM5 target model uses rotary quantization, the final hidden states
passed to the MTP module need an extra rotary projection.
- vLLM version: v0.16.0
- vLLM main: 4034c3d32e
---------
Signed-off-by: Wangbingjie <wangbj1207@126.com>
Signed-off-by: wangbj127 <256472688+wangbj127@users.noreply.github.com>
import torch
import torch.nn as nn
import vllm
from transformers import DeepseekV2Config, DeepseekV3Config
from vllm.config import VllmConfig
from vllm.model_executor.models.deepseek_mtp import DeepSeekMTP, DeepSeekMultiTokenPredictorLayer

MTP_ROT_WEIGHT_NAME = "rot.weight"

def get_spec_layer_idx_from_weight_name(config: DeepseekV2Config | DeepseekV3Config,
                                        weight_name: str) -> int | None:
    # Patched helper: in addition to the upstream "model.layers.<idx>." match,
    # route the bare "rot.weight" checkpoint entry to the MTP layer.
    if hasattr(config, "num_nextn_predict_layers") and config.num_nextn_predict_layers > 0:
        layer_idx = config.num_hidden_layers
        for i in range(config.num_nextn_predict_layers):
            if (weight_name.startswith(f"model.layers.{layer_idx + i}.")
                    or weight_name.startswith(MTP_ROT_WEIGHT_NAME)):
                return layer_idx + i
    return None
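
# A quick illustration of the routing above, under an assumed config with
# num_hidden_layers=61 and num_nextn_predict_layers=1 (values are hypothetical):
#   get_spec_layer_idx_from_weight_name(cfg, "model.layers.61.mlp.gate.weight")  -> 61
#   get_spec_layer_idx_from_weight_name(cfg, "rot.weight")                       -> 61
#   get_spec_layer_idx_from_weight_name(cfg, "model.layers.0.self_attn.weight")  -> None
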
class AscendDeepSeekMultiTokenPredictorLayer(DeepSeekMultiTokenPredictorLayer):

    def __init__(self, vllm_config: VllmConfig, prefix: str) -> None:
        super().__init__(vllm_config, prefix)
        quant_description = getattr(vllm_config.quant_config, "quant_description", None)
        self.is_rot_used = quant_description.get("is_rot_used", False) if quant_description is not None else False
        self.target_model_type = vllm_config.speculative_config.target_model_config.hf_text_config.model_type
        # The extra rotary projection is only needed when the GLM5 target model
        # was quantized with rotation enabled.
        if self.is_rot_used and self.target_model_type == "glm_moe_dsa":
            self.rot = nn.Linear(self.config.hidden_size, self.config.hidden_size, bias=False)
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        previous_hidden_states: torch.Tensor,
        inputs_embeds: torch.Tensor | None = None,
        spec_step_index: int = 0,
    ) -> torch.Tensor:
        assert inputs_embeds is not None
        # Masking inputs at position 0, as not needed by MTP.
        inputs_embeds = torch.where(positions.unsqueeze(-1) == 0, 0, inputs_embeds)
        inputs_embeds = self.enorm(inputs_embeds)
        # Apply the extra rotary projection to the target model's final hidden
        # states before they are normalized and fused with the embeddings.
        if self.is_rot_used and self.target_model_type == "glm_moe_dsa":
            previous_hidden_states = self.rot(previous_hidden_states)
        previous_hidden_states = self.hnorm(previous_hidden_states)

        hidden_states = self.eh_proj(torch.cat([inputs_embeds, previous_hidden_states], dim=-1))

        hidden_states, residual = self.mtp_block(positions=positions,
                                                 hidden_states=hidden_states,
                                                 residual=None)
        hidden_states = residual + hidden_states
        return hidden_states

class AscendDeepSeekMTP(DeepSeekMTP):

    def _rewrite_spec_layer_name(self, spec_layer: int, name: str) -> str:
        if name != MTP_ROT_WEIGHT_NAME:
            return super()._rewrite_spec_layer_name(spec_layer, name)
        return f"model.layers.{spec_layer}.rot.weight"
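
# For illustration, with an assumed spec_layer of 61 the bare checkpoint entry
# "rot.weight" is remapped into that layer's namespace:
#   _rewrite_spec_layer_name(61, "rot.weight") -> "model.layers.61.rot.weight"
# All other names fall through to the upstream rewrite rules.
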
# Monkey-patch vLLM so the weight-name helper and the MTP classes resolve to
# the Ascend variants defined above.
vllm.model_executor.models.deepseek_v2.get_spec_layer_idx_from_weight_name = get_spec_layer_idx_from_weight_name
vllm.model_executor.models.deepseek_mtp.get_spec_layer_idx_from_weight_name = get_spec_layer_idx_from_weight_name
vllm.model_executor.models.deepseek_mtp.DeepSeekMultiTokenPredictorLayer = AscendDeepSeekMultiTokenPredictorLayer
vllm.model_executor.models.deepseek_mtp.DeepSeekMTP = AscendDeepSeekMTP
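
# Minimal usage sketch (not part of the patch): the assignments above take
# effect as an import side effect, so this module must be imported before vLLM
# builds the MTP model. The module name below is a hypothetical placeholder.
#
#   import ascend_patch_deepseek_mtp  # hypothetical name for this file
#   from vllm import LLM              # MTP classes resolved after the patch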