[0.11.0][MTP][Aclgraph] Fix the support aclgraph with MTP (#3912)

### What this PR does / why we need it?
Fix 2 breaks of aclgraph with MTP:
1. deepseekmtp in vllm 0.11.0 does not support aclgraph and lack the
`support_torch_compile` decorator
2. There is a d2h synchornization in the original forward of mtp
predictor. The fix pr in vllm
https://github.com/vllm-project/vllm/pull/27643

As we'll fix it in vllm main, this fix pr is only needed in branch
v0.11.0-dev

The profling shows that MTP replays in aclgraph now:
<img width="1612" height="1866" alt="a7d7f04155df4ed454b7eb20a92b2e2a"
src="https://github.com/user-attachments/assets/eaa4b9ff-aeb0-416d-964f-5a06e497f155"
/>

### How was this patch tested?
<!--
CI passed with new added/existing test.
If it was tested in a way different from regular unit tests, please
clarify how you tested step by step, ideally copy and paste-able, so
that other reviewers can test and check, and descendants can verify in
the future.
If tests were not added, please describe why they were not added and/or
why it was difficult to add.
-->

---------

Signed-off-by: MengqingCao <cmq0113@163.com>
This commit is contained in:
Mengqing Cao
2025-11-03 14:25:37 +08:00
committed by GitHub
parent 8a7154001e
commit 7cc6208029
2 changed files with 43 additions and 3 deletions

View File

@@ -1,16 +1,54 @@
from typing import Optional
import torch
import torch.nn as nn
import vllm
from transformers import PretrainedConfig
from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.models.deepseek_mtp import \
DeepSeekMultiTokenPredictorLayer
from vllm.model_executor.models.deepseek_mtp import (
DeepSeekMTP, DeepSeekMultiTokenPredictorLayer)
from vllm.model_executor.models.deepseek_v2 import DeepseekV2DecoderLayer
from vllm.model_executor.models.utils import maybe_prefix
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
previous_hidden_states: torch.Tensor,
inputs_embeds: Optional[torch.Tensor] = None,
spec_step_index: int = 0,
) -> torch.Tensor:
assert inputs_embeds is not None
# masking inputs at position 0, as not needed by MTP
# Patch this for aclgraph support, as the original operation introduced d2h sync,
# which breaks aclgraph
inputs_embeds = torch.where(positions.unsqueeze(-1) == 0, 0, inputs_embeds)
inputs_embeds = self.enorm(inputs_embeds)
previous_hidden_states = self.hnorm(previous_hidden_states)
hidden_states = self.eh_proj(
torch.cat([inputs_embeds, previous_hidden_states], dim=-1))
hidden_states, residual = self.mtp_block(positions=positions,
hidden_states=hidden_states,
residual=None)
hidden_states = residual + hidden_states
return hidden_states
# Patch this only for aclgraph support, as this is not support in vLLM 0.11.0
@support_torch_compile
class AscendDeepSeekMTP(DeepSeekMTP):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__(vllm_config=vllm_config, prefix=prefix)
class SharedHead(nn.Module):
def __init__(
@@ -53,3 +91,4 @@ def predictor_init(self, vllm_config: VllmConfig, prefix: str) -> None:
DeepSeekMultiTokenPredictorLayer.__init__ = predictor_init
vllm.model_executor.models.deepseek_mtp.DeepSeekMultiTokenPredictorLayer.forward = forward

View File

@@ -11,7 +11,6 @@ from vllm.forward_context import BatchDescriptor, get_forward_context
from vllm.model_executor.model_loader import get_model_loader
from vllm.model_executor.model_loader.utils import (
process_weights_after_loading, set_default_torch_dtype)
from vllm.model_executor.models.deepseek_mtp import DeepSeekMTP
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
@@ -19,6 +18,8 @@ from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.ascend_forward_context import set_ascend_forward_context
from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
from vllm_ascend.patch.worker.patch_deepseek_mtp import \
AscendDeepSeekMTP as DeepSeekMTP
from vllm_ascend.spec_decode.interface import Proposer, SpecDcodeType
from vllm_ascend.torchair.models.torchair_deepseek_mtp import \
TorchairDeepSeekMTP