[MTP][Bugfix] Fix GLM5-W8A8 precision issues caused by rotary quant MTP weights (#7139)
### What this PR does / why we need it?
When GLM5 target model uses rotary quant, the final hidden states passes
to MTP need to do an extra rotary.
- vLLM version: v0.16.0
- vLLM main:
4034c3d32e
---------
Signed-off-by: Wangbingjie <wangbj1207@126.com>
Signed-off-by: wangbj127 <256472688+wangbj127@users.noreply.github.com>
This commit is contained in:
@@ -452,3 +452,35 @@
|
||||
# https://github.com/vllm-project/vllm/pull/34880
|
||||
# Future Plan:
|
||||
# Remove this patch when vLLM merges the PR.
|
||||
#
|
||||
# ** 21. File: worker/patch_deepseek_mtp.py**
|
||||
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
# 1. `vllm.model_executor.models.deepseek_v2.get_spec_layer_idx_from_weight_name` and
|
||||
# `vllm.model_executor.models.deepseek_mtp.get_spec_layer_idx_from_weight_name`
|
||||
# Why:
|
||||
# When GLM5 uses rotary quant in vllm-ascend, the MTP layer needs to load an extra weight
|
||||
# named `rot.weight`.
|
||||
# How:
|
||||
# If weight name starts with `rot`, return `layer_id + i` like other tensors in MTP layer.
|
||||
# Related PR (if no, explain why):
|
||||
# Rotary quant is a unique feature of vllm-ascend.
|
||||
# Future Plan:
|
||||
# Remove this patch when vllm supports rotary quant or pluggable `MultiTokenPredictorLayer`.
|
||||
# 2. `vllm.model_executor.models.deepseek_mtp.DeepSeekMultiTokenPredictorLayer`
|
||||
# Why:
|
||||
# When GLM5 uses rotary quant in vllm-ascend, the `previous_hidden_states` does not .
|
||||
# How:
|
||||
# If the target model uses rotary quant, a new linear operation is added before `ehnorm`.
|
||||
# Related PR (if no, explain why):
|
||||
# Rotary quant is a unique feature of vllm-ascend.
|
||||
# Future Plan:
|
||||
# Remove this patch when vllm supports rotary quant or pluggable `MultiTokenPredictorLayer`.
|
||||
# 3. `vllm.model_executor.models.deepseek_mtp.DeepSeekMTP._rewrite_spec_layer_name`
|
||||
# Why:
|
||||
# Rename `rot.weight` to match the format of weights in `DeepSeekMTP`.
|
||||
# How:
|
||||
# If the weight name is `rot`, rename it to `model.layers.{spec_layer}.rot.weight`.
|
||||
# Related PR (if no, explain why):
|
||||
# Rotary quant is a unique feature of vllm-ascend.
|
||||
# Future Plan:
|
||||
# Remove this patch when vllm supports rotary quant or pluggable `MultiTokenPredictorLayer`.
|
||||
|
||||
@@ -44,3 +44,4 @@ import vllm_ascend.patch.worker.patch_npugraph_ex_triton # noqa
|
||||
import vllm_ascend.patch.worker.patch_kimi_k25 # noqa
|
||||
import vllm_ascend.patch.worker.patch_draft_quarot # noqa
|
||||
import vllm_ascend.patch.worker.patch_cudagraph # noqa
|
||||
import vllm_ascend.patch.worker.patch_deepseek_mtp # noqa
|
||||
|
||||
63
vllm_ascend/patch/worker/patch_deepseek_mtp.py
Normal file
63
vllm_ascend/patch/worker/patch_deepseek_mtp.py
Normal file
@@ -0,0 +1,63 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import vllm
|
||||
from transformers import DeepseekV2Config, DeepseekV3Config
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.model_executor.models.deepseek_mtp import DeepSeekMTP, DeepSeekMultiTokenPredictorLayer
|
||||
|
||||
MTP_ROT_WEIGHT_NAME = "rot.weight"
|
||||
|
||||
|
||||
def get_spec_layer_idx_from_weight_name(config: DeepseekV2Config | DeepseekV3Config, weight_name: str) -> int | None:
|
||||
if hasattr(config, "num_nextn_predict_layers") and config.num_nextn_predict_layers > 0:
|
||||
layer_idx = config.num_hidden_layers
|
||||
for i in range(config.num_nextn_predict_layers):
|
||||
if weight_name.startswith(f"model.layers.{layer_idx + i}.") or weight_name.startswith(MTP_ROT_WEIGHT_NAME):
|
||||
return layer_idx + i
|
||||
return None
|
||||
|
||||
|
||||
class AscendDeepSeekMultiTokenPredictorLayer(DeepSeekMultiTokenPredictorLayer):
|
||||
def __init__(self, vllm_config: VllmConfig, prefix: str) -> None:
|
||||
super().__init__(vllm_config, prefix)
|
||||
quant_description = getattr(vllm_config.quant_config, "quant_description", None)
|
||||
self.is_rot_used = quant_description.get("is_rot_used", False) if quant_description is not None else False
|
||||
self.target_model_type = vllm_config.speculative_config.target_model_config.hf_text_config.model_type
|
||||
if self.is_rot_used and self.target_model_type == "glm_moe_dsa":
|
||||
self.rot = nn.Linear(self.config.hidden_size, self.config.hidden_size, bias=False)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
previous_hidden_states: torch.Tensor,
|
||||
inputs_embeds: torch.Tensor | None = None,
|
||||
spec_step_index: int = 0,
|
||||
) -> torch.Tensor:
|
||||
assert inputs_embeds is not None
|
||||
# masking inputs at position 0, as not needed by MTP
|
||||
inputs_embeds = torch.where(positions.unsqueeze(-1) == 0, 0, inputs_embeds)
|
||||
inputs_embeds = self.enorm(inputs_embeds)
|
||||
if self.is_rot_used and self.target_model_type == "glm_moe_dsa":
|
||||
previous_hidden_states = self.rot(previous_hidden_states)
|
||||
previous_hidden_states = self.hnorm(previous_hidden_states)
|
||||
|
||||
hidden_states = self.eh_proj(torch.cat([inputs_embeds, previous_hidden_states], dim=-1))
|
||||
|
||||
hidden_states, residual = self.mtp_block(positions=positions, hidden_states=hidden_states, residual=None)
|
||||
hidden_states = residual + hidden_states
|
||||
return hidden_states
|
||||
|
||||
|
||||
class AscendDeepSeekMTP(DeepSeekMTP):
|
||||
def _rewrite_spec_layer_name(self, spec_layer: int, name: str) -> str:
|
||||
if name != MTP_ROT_WEIGHT_NAME:
|
||||
return super()._rewrite_spec_layer_name(spec_layer, name)
|
||||
else:
|
||||
return f"model.layers.{spec_layer}.rot.weight"
|
||||
|
||||
|
||||
vllm.model_executor.models.deepseek_v2.get_spec_layer_idx_from_weight_name = get_spec_layer_idx_from_weight_name
|
||||
vllm.model_executor.models.deepseek_mtp.get_spec_layer_idx_from_weight_name = get_spec_layer_idx_from_weight_name
|
||||
vllm.model_executor.models.deepseek_mtp.DeepSeekMultiTokenPredictorLayer = AscendDeepSeekMultiTokenPredictorLayer
|
||||
vllm.model_executor.models.deepseek_mtp.DeepSeekMTP = AscendDeepSeekMTP
|
||||
Reference in New Issue
Block a user