From 46e62efd44cfccac5b07d0694dd0740efb404fb4 Mon Sep 17 00:00:00 2001
From: anon189Ty
Date: Fri, 17 Oct 2025 18:14:49 +0800
Subject: [PATCH] [Feat] MTP ACLGraph support (#3244)

### What this PR does / why we need it?
Currently, the DeepSeek MTP model cannot be captured in ACLGraph. This PR
allows the MTP draft model to be captured in ACLGraph mode.
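
A minimal offline-inference sketch of the path this change targets (hedged:
the checkpoint, parallelism, and sampling values below are illustrative
placeholders, and ACLGraph is assumed to be the active graph mode, i.e.
torchair graph mode is not enabled):

```python
# Sketch only: the model path, tensor_parallel_size and sampling values are
# placeholders, not part of this patch. "deepseek_mtp" is vLLM's built-in
# speculative-decoding method for DeepSeek's multi-token-prediction layer.
from vllm import LLM, SamplingParams

llm = LLM(
    model="deepseek-ai/DeepSeek-V3",  # placeholder: any checkpoint shipping MTP weights
    tensor_parallel_size=16,          # placeholder parallelism for a full-size model
    speculative_config={
        "method": "deepseek_mtp",
        "num_speculative_tokens": 1,  # DeepSeek-V3 ships a single MTP draft layer
    },
)
outputs = llm.generate(["The capital of France is"],
                       SamplingParams(temperature=0.0, max_tokens=32))
print(outputs[0].outputs[0].text)
```

With this change, the MTP proposer's dummy runs receive the same
`aclgraph_runtime_mode` and `batch_descriptor` as the main model, and
`CustomDeepSeekMTP` is decorated with `@support_torch_compile`, so the draft
forward pass can be captured and replayed as a graph instead of running
eagerly.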
"""Called by dummy_run in modle_runner""" raise NotImplementedError diff --git a/vllm_ascend/spec_decode/mtp_proposer.py b/vllm_ascend/spec_decode/mtp_proposer.py index 3df4f69..a3baabf 100644 --- a/vllm_ascend/spec_decode/mtp_proposer.py +++ b/vllm_ascend/spec_decode/mtp_proposer.py @@ -5,8 +5,8 @@ import torch.nn as nn import torchair from torchair import patch_for_hcom from vllm.attention.layer import Attention -from vllm.config import (VllmConfig, get_layers_from_vllm_config, - set_current_vllm_config) +from vllm.config import (CUDAGraphMode, VllmConfig, + get_layers_from_vllm_config, set_current_vllm_config) from vllm.forward_context import BatchDescriptor, get_forward_context from vllm.model_executor.model_loader import get_model_loader from vllm.model_executor.model_loader.utils import ( @@ -109,7 +109,9 @@ class MtpProposer(Proposer): with_prefill: bool = False, skip_attn: bool = False, num_reqs: int = 0, - num_tokens_across_dp=None) -> None: + num_tokens_across_dp=None, + aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE, + batch_descriptor=None) -> None: if not self.torchair_graph_enabled: # TODO: adapt enable_dbo later (num_tokens, num_tokens_across_dp, with_prefill, @@ -151,7 +153,9 @@ class MtpProposer(Proposer): reserved_mc2_mask=self.runner.reserved_mc2_mask, moe_comm_type=moe_comm_type, in_profile_run=self.runner.in_profile_run, - num_actual_tokens=0): + num_actual_tokens=0, + aclgraph_runtime_mode=aclgraph_runtime_mode, + batch_descriptor=batch_descriptor): if is_running_torchair: assert attn_metadata is not None torch._dynamo.mark_static(input_ids) @@ -442,6 +446,7 @@ class MtpProposer(Proposer): reserved_mc2_mask=self.runner.reserved_mc2_mask, moe_comm_type=moe_comm_type, aclgraph_runtime_mode=aclgraph_runtime_mode, + batch_descriptor=batch_descriptor, in_profile_run=self.runner.in_profile_run, num_actual_tokens=num_tokens): with ProfileExecuteDuration().capture_async('mtp_forward'): diff --git a/vllm_ascend/spec_decode/ngram_proposer.py b/vllm_ascend/spec_decode/ngram_proposer.py index 9999f1f..34b5b95 100644 --- a/vllm_ascend/spec_decode/ngram_proposer.py +++ b/vllm_ascend/spec_decode/ngram_proposer.py @@ -1,4 +1,5 @@ import torch +from vllm.config import CUDAGraphMode from vllm.v1.spec_decode.ngram_proposer import \ NgramProposer as VllmNgramProposer @@ -23,7 +24,9 @@ class NgramProposer(VllmNgramProposer, Proposer): with_prefill=None, skip_attn=None, num_reqs=None, - num_tokens_across_dp=None): + num_tokens_across_dp=None, + aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE, + batch_descriptor=None): pass def generate_token_ids(self, diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index 731b93f..96bf679 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -2479,7 +2479,9 @@ class NPUModelRunner(LoRAModelRunnerMixin): with_prefill=with_prefill, skip_attn=True, num_reqs=num_reqs, - num_tokens_across_dp=num_tokens_across_dp) + num_tokens_across_dp=num_tokens_across_dp, + aclgraph_runtime_mode=aclgraph_runtime_mode, + batch_descriptor=batch_descriptor) if need_dummy_logits: dummy_compute_logits(hidden_states) if self.in_profile_run and self.dynamic_eplb: