[Feat]mtp aclgraph support (#3244)

### What this PR does / why we need it?
Currently, MTP Model in deepseek can not be capture in ACLGraph. This PR
is use to allow MTP to be captured in ACLGraph mode.

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?


- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0

Signed-off-by: anon189Ty <Stari_Falcon@outlook.com>
This commit is contained in:
anon189Ty
2025-10-17 18:14:49 +08:00
committed by GitHub
parent 1b424fb7f1
commit 46e62efd44
6 changed files with 26 additions and 10 deletions

View File

@@ -23,6 +23,7 @@ import torch
import torch.nn as nn
from transformers import PretrainedConfig
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import (CacheConfig, ModelConfig, VllmConfig,
get_current_vllm_config)
from vllm.model_executor.layers.layernorm import RMSNorm
@@ -179,6 +180,7 @@ class CustomDeepSeekMultiTokenPredictor(DeepSeekMultiTokenPredictor):
return logits
@support_torch_compile
class CustomDeepSeekMTP(DeepSeekMTP):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):