diff --git a/vllm_ascend/patch/__init__.py b/vllm_ascend/patch/__init__.py index b4d4148d..4a1b0a16 100644 --- a/vllm_ascend/patch/__init__.py +++ b/vllm_ascend/patch/__init__.py @@ -452,3 +452,35 @@ # https://github.com/vllm-project/vllm/pull/34880 # Future Plan: # Remove this patch when vLLM merges the PR. +# +# ** 21. File: worker/patch_deepseek_mtp.py** +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# 1. `vllm.model_executor.models.deepseek_v2.get_spec_layer_idx_from_weight_name` and +# `vllm.model_executor.models.deepseek_mtp.get_spec_layer_idx_from_weight_name` +# Why: +# When GLM5 uses rotary quant in vllm-ascend, the MTP layer needs to load an extra weight +# named `rot.weight`. +# How: +# If the weight name starts with `rot.weight`, return `layer_id + i` like other tensors in the MTP layer. +# Related PR (if no, explain why): +# Rotary quant is a unique feature of vllm-ascend. +# Future Plan: +# Remove this patch when vllm supports rotary quant or pluggable `MultiTokenPredictorLayer`. +# 2. `vllm.model_executor.models.deepseek_mtp.DeepSeekMultiTokenPredictorLayer` +# Why: +# When GLM5 uses rotary quant in vllm-ascend, the `previous_hidden_states` must go through the extra `rot` linear layer before `hnorm`. +# How: +# If the target model uses rotary quant, a new linear operation is added before `hnorm`. +# Related PR (if no, explain why): +# Rotary quant is a unique feature of vllm-ascend. +# Future Plan: +# Remove this patch when vllm supports rotary quant or pluggable `MultiTokenPredictorLayer`. +# 3. `vllm.model_executor.models.deepseek_mtp.DeepSeekMTP._rewrite_spec_layer_name` +# Why: +# Rename `rot.weight` to match the format of weights in `DeepSeekMTP`. +# How: +# If the weight name is `rot.weight`, rename it to `model.layers.{spec_layer}.rot.weight`. +# Related PR (if no, explain why): +# Rotary quant is a unique feature of vllm-ascend. +# Future Plan: +# Remove this patch when vllm supports rotary quant or pluggable `MultiTokenPredictorLayer`. 
diff --git a/vllm_ascend/patch/worker/__init__.py b/vllm_ascend/patch/worker/__init__.py
index 1664c84a..a847cac2 100644
--- a/vllm_ascend/patch/worker/__init__.py
+++ b/vllm_ascend/patch/worker/__init__.py
@@ -44,3 +44,4 @@ import vllm_ascend.patch.worker.patch_npugraph_ex_triton # noqa
 import vllm_ascend.patch.worker.patch_kimi_k25 # noqa
 import vllm_ascend.patch.worker.patch_draft_quarot # noqa
 import vllm_ascend.patch.worker.patch_cudagraph # noqa
+import vllm_ascend.patch.worker.patch_deepseek_mtp # noqa
diff --git a/vllm_ascend/patch/worker/patch_deepseek_mtp.py b/vllm_ascend/patch/worker/patch_deepseek_mtp.py
new file mode 100644
index 00000000..bc147d74
--- /dev/null
+++ b/vllm_ascend/patch/worker/patch_deepseek_mtp.py
@@ -0,0 +1,94 @@
+"""Ascend patch that lets vLLM's DeepSeek MTP model load and apply `rot.weight`."""
+
+import torch
+import torch.nn as nn
+import vllm
+from transformers import DeepseekV2Config, DeepseekV3Config
+from vllm.config import VllmConfig
+from vllm.model_executor.models.deepseek_mtp import DeepSeekMTP, DeepSeekMultiTokenPredictorLayer
+
+# Checkpoint name of the extra weight consumed by the patched MTP layer.
+MTP_ROT_WEIGHT_NAME = "rot.weight"
+
+
+def get_spec_layer_idx_from_weight_name(config: DeepseekV2Config | DeepseekV3Config, weight_name: str) -> int | None:
+    """Return the index of the MTP layer that `weight_name` belongs to.
+
+    MTP layers are indexed `num_hidden_layers + i` for each `i` below
+    `num_nextn_predict_layers`. A name starting with `rot.weight` is assigned
+    to the first MTP layer. Returns None when the config declares no MTP layers
+    or the name matches none of them.
+    """
+    num_mtp_layers = getattr(config, "num_nextn_predict_layers", 0)
+    if not (num_mtp_layers > 0):
+        return None
+
+    first_mtp_layer = config.num_hidden_layers
+    if weight_name.startswith(MTP_ROT_WEIGHT_NAME):
+        return first_mtp_layer
+    for mtp_layer in range(first_mtp_layer, first_mtp_layer + num_mtp_layers):
+        if weight_name.startswith(f"model.layers.{mtp_layer}."):
+            return mtp_layer
+    return None
+
+
+class AscendDeepSeekMultiTokenPredictorLayer(DeepSeekMultiTokenPredictorLayer):
+    """MTP layer that optionally applies a `rot` linear map before `hnorm`."""
+
+    def __init__(self, vllm_config: VllmConfig, prefix: str) -> None:
+        """Build the base layer, then add `rot` when rotary quant applies."""
+        super().__init__(vllm_config, prefix)
+        quant_description = getattr(vllm_config.quant_config, "quant_description", None)
+        if quant_description is None:
+            self.is_rot_used = False
+        else:
+            self.is_rot_used = quant_description.get("is_rot_used", False)
+        target_hf_config = vllm_config.speculative_config.target_model_config.hf_text_config
+        self.target_model_type = target_hf_config.model_type
+        if self._needs_rot():
+            hidden_size = self.config.hidden_size
+            self.rot = nn.Linear(hidden_size, hidden_size, bias=False)
+
+    def _needs_rot(self) -> bool:
+        """Tell whether `rot` must be applied: rot quant on a `glm_moe_dsa` target."""
+        return bool(self.is_rot_used and self.target_model_type == "glm_moe_dsa")
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        previous_hidden_states: torch.Tensor,
+        inputs_embeds: torch.Tensor | None = None,
+        spec_step_index: int = 0,
+    ) -> torch.Tensor:
+        """Run one MTP step; `input_ids` and `spec_step_index` are unused here."""
+        assert inputs_embeds is not None
+        # Zero the embeddings at position 0: MTP does not need them.
+        at_position_zero = positions.unsqueeze(-1) == 0
+        normed_embeds = self.enorm(torch.where(at_position_zero, 0, inputs_embeds))
+        if self._needs_rot():
+            previous_hidden_states = self.rot(previous_hidden_states)
+        normed_hidden = self.hnorm(previous_hidden_states)
+
+        fused = self.eh_proj(torch.cat([normed_embeds, normed_hidden], dim=-1))
+        block_out, residual = self.mtp_block(positions=positions, hidden_states=fused, residual=None)
+        return residual + block_out
+
+
+class AscendDeepSeekMTP(DeepSeekMTP):
+    """`DeepSeekMTP` whose weight renaming also handles `rot.weight`."""
+
+    def _rewrite_spec_layer_name(self, spec_layer: int, name: str) -> str:
+        """Map `rot.weight` under `model.layers.{spec_layer}`; defer other names to the base class."""
+        if name == MTP_ROT_WEIGHT_NAME:
+            return f"model.layers.{spec_layer}.rot.weight"
+        return super()._rewrite_spec_layer_name(spec_layer, name)
+
+
+# Install the overrides on the vLLM modules.
+_deepseek_v2 = vllm.model_executor.models.deepseek_v2
+_deepseek_mtp = vllm.model_executor.models.deepseek_mtp
+_deepseek_v2.get_spec_layer_idx_from_weight_name = get_spec_layer_idx_from_weight_name
+_deepseek_mtp.get_spec_layer_idx_from_weight_name = get_spec_layer_idx_from_weight_name
+_deepseek_mtp.DeepSeekMultiTokenPredictorLayer = AscendDeepSeekMultiTokenPredictorLayer
+_deepseek_mtp.DeepSeekMTP = AscendDeepSeekMTP