diff --git a/vllm_ascend/patch/worker/patch_deepseek_mtp.py b/vllm_ascend/patch/worker/patch_deepseek_mtp.py
index f64ebbe..68ac359 100644
--- a/vllm_ascend/patch/worker/patch_deepseek_mtp.py
+++ b/vllm_ascend/patch/worker/patch_deepseek_mtp.py
@@ -1,16 +1,54 @@
+from typing import Optional
+
 import torch
 import torch.nn as nn
+import vllm
 from transformers import PretrainedConfig
+from vllm.compilation.decorators import support_torch_compile
 from vllm.config import VllmConfig
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
-from vllm.model_executor.models.deepseek_mtp import \
-    DeepSeekMultiTokenPredictorLayer
+from vllm.model_executor.models.deepseek_mtp import (
+    DeepSeekMTP, DeepSeekMultiTokenPredictorLayer)
 from vllm.model_executor.models.deepseek_v2 import DeepseekV2DecoderLayer
 from vllm.model_executor.models.utils import maybe_prefix
 
 
+def forward(
+    self,
+    input_ids: torch.Tensor,
+    positions: torch.Tensor,
+    previous_hidden_states: torch.Tensor,
+    inputs_embeds: Optional[torch.Tensor] = None,
+    spec_step_index: int = 0,
+) -> torch.Tensor:
+    assert inputs_embeds is not None
+    # Mask inputs at position 0, as they are not needed by MTP.
+    # Patched for aclgraph support: the original operation introduced a
+    # d2h sync, which breaks aclgraph.
+    inputs_embeds = torch.where(positions.unsqueeze(-1) == 0, 0, inputs_embeds)
+    inputs_embeds = self.enorm(inputs_embeds)
+    previous_hidden_states = self.hnorm(previous_hidden_states)
+
+    hidden_states = self.eh_proj(
+        torch.cat([inputs_embeds, previous_hidden_states], dim=-1))
+
+    hidden_states, residual = self.mtp_block(positions=positions,
+                                             hidden_states=hidden_states,
+                                             residual=None)
+    hidden_states = residual + hidden_states
+    return hidden_states
+
+
+# Patched only for aclgraph support, as this is not supported in vLLM 0.11.0.
+@support_torch_compile
+class AscendDeepSeekMTP(DeepSeekMTP):
+
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+        super().__init__(vllm_config=vllm_config, prefix=prefix)
+
+
 class SharedHead(nn.Module):
 
     def __init__(
@@ -53,3 +91,4 @@ def predictor_init(self, vllm_config: VllmConfig, prefix: str) -> None:
 
 
 DeepSeekMultiTokenPredictorLayer.__init__ = predictor_init
+vllm.model_executor.models.deepseek_mtp.DeepSeekMultiTokenPredictorLayer.forward = forward
diff --git a/vllm_ascend/spec_decode/mtp_proposer.py b/vllm_ascend/spec_decode/mtp_proposer.py
index 2a11731..d8b25e8 100644
--- a/vllm_ascend/spec_decode/mtp_proposer.py
+++ b/vllm_ascend/spec_decode/mtp_proposer.py
@@ -11,7 +11,6 @@ from vllm.forward_context import BatchDescriptor, get_forward_context
 from vllm.model_executor.model_loader import get_model_loader
 from vllm.model_executor.model_loader.utils import (
     process_weights_after_loading, set_default_torch_dtype)
-from vllm.model_executor.models.deepseek_mtp import DeepSeekMTP
 from vllm.v1.core.sched.output import SchedulerOutput
 from vllm.v1.sample.metadata import SamplingMetadata
 from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
@@ -19,6 +18,8 @@ from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
 from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.ascend_forward_context import set_ascend_forward_context
 from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
+from vllm_ascend.patch.worker.patch_deepseek_mtp import \
+    AscendDeepSeekMTP as DeepSeekMTP
 from vllm_ascend.spec_decode.interface import Proposer, SpecDcodeType
 from vllm_ascend.torchair.models.torchair_deepseek_mtp import \
     TorchairDeepSeekMTP
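
For context, a minimal sketch of why the `torch.where` rewrite matters for graph capture. It assumes the original upstream masking used boolean-index assignment (e.g. `inputs_embeds[positions == 0] = 0`, as the patch comment implies); the tensor names and shapes below are illustrative only:

```python
import torch

# Illustrative shapes: 4 tokens, hidden size 8.
positions = torch.tensor([0, 1, 2, 0])
inputs_embeds = torch.randn(4, 8)

# Boolean-index assignment: resolving the mask into concrete indices is
# data-dependent, which can force a device-to-host sync and thus breaks
# whole-graph capture (aclgraph).
masked = inputs_embeds.clone()
masked[positions == 0] = 0

# torch.where: a fixed-shape, fully on-device select with no
# data-dependent indexing, so it is safe to capture into a graph.
masked_where = torch.where(positions.unsqueeze(-1) == 0, 0, inputs_embeds)

assert torch.equal(masked, masked_where)
```

Both forms compute the same result; only the `torch.where` form avoids the host round-trip.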