[MTP][V1] Adapt mtp with graph mode in v1. (#1023)

Adapts DeepSeek MTP (multi-token prediction) speculative decoding to the TorchAir graph mode in the V1 engine.
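
For reviewers unfamiliar with the feature, the sketch below shows roughly how DeepSeek MTP speculative decoding and the TorchAir graph mode are enabled together. It is illustrative only; the config keys ("method", "torchair_graph_config") and the model name are assumptions and may differ across vllm-ascend releases.

    # Illustrative launch sketch -- config keys are assumptions, not part of this diff.
    from vllm import LLM, SamplingParams

    llm = LLM(
        model="deepseek-ai/DeepSeek-V3",        # any checkpoint shipping MTP weights
        speculative_config={
            "method": "deepseek_mtp",           # draft with the model's own MTP head
            "num_speculative_tokens": 1,
        },
        additional_config={
            # vllm-ascend specific: run decode steps through a captured TorchAir graph
            "torchair_graph_config": {"enabled": True},
        },
    )
    outputs = llm.generate(["Hello"], SamplingParams(max_tokens=32))
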

---------

Signed-off-by: whx-sjtu <2952154980@qq.com>
Author: whx
Date: 2025-06-09 22:21:42 +08:00
Committed by: GitHub
Parent: 5ac4872f5e
Commit: cd2f14a1b3
4 changed files with 87 additions and 24 deletions


@@ -203,8 +203,13 @@ class NPUModelRunner(LoRAModelRunnerMixin):
         # Set up speculative decoding.
         self.use_spec_decode = False
+        self.spec_attn_mask = None
         if self.speculative_config:
             self.use_spec_decode = True
+            self.spec_attn_mask = torch.triu(torch.ones(2048,
+                                                        2048,
+                                                        dtype=torch.bool),
+                                             diagonal=1).to("npu")
             if get_pp_group().is_last_rank:
                 if self.speculative_config.method == "ngram":
                     self.drafter = NgramProposer(self.vllm_config)
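
The mask added above is an ordinary causal mask, precomputed once at a fixed 2048x2048 size and kept on the NPU, presumably sliced per step by the speculative-decoding attention path. A tiny CPU-side sketch (4x4 instead of 2048x2048) of what the same torch.triu call produces:

    import torch

    # Same construction as the diff, shrunk to 4x4 for illustration.
    mask = torch.triu(torch.ones(4, 4, dtype=torch.bool), diagonal=1)
    print(mask)
    # tensor([[False,  True,  True,  True],
    #         [False, False,  True,  True],
    #         [False, False, False,  True],
    #         [False, False, False, False]])
    # True marks the future positions a token must not attend to.
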
@@ -779,10 +784,13 @@ class NPUModelRunner(LoRAModelRunnerMixin):
         # Get the number of scheduled tokens for each request.
         # TODO: The Python loop can be slow. Optimize.
         num_scheduled_tokens = np.empty(num_reqs, dtype=np.int32)
+        num_valid_tokens = np.empty(num_reqs, dtype=np.int32)
         max_num_scheduled_tokens = 0
         for i, req_id in enumerate(self.input_batch.req_ids):
             num_tokens = scheduler_output.num_scheduled_tokens[req_id]
             num_scheduled_tokens[i] = num_tokens
+            num_valid_tokens[i] = num_tokens - \
+                len(scheduler_output.scheduled_spec_decode_tokens.get(req_id, []))
             max_num_scheduled_tokens = max(max_num_scheduled_tokens,
                                            num_tokens)
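
The num_valid_tokens array introduced here is the per-request count of scheduled tokens minus the draft tokens queued for verification. A toy numpy example of the same arithmetic (the numbers are made up):

    import numpy as np

    # req0: plain decode (1 token); req1: 1 real token + 2 draft tokens to verify;
    # req2: a 7-token prefill chunk.
    num_scheduled_tokens = np.array([1, 3, 7], dtype=np.int32)
    num_spec_tokens      = np.array([0, 2, 0], dtype=np.int32)

    num_valid_tokens = num_scheduled_tokens - num_spec_tokens   # -> [1, 1, 7]
    # If every entry is 1, the batch is decode-shaped even though some requests
    # carry extra draft tokens -- the test the runner performs further down.
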
@@ -838,11 +846,16 @@ class NPUModelRunner(LoRAModelRunnerMixin):
                out=self.slot_mapping_np[:total_num_scheduled_tokens])
         ascend_config = get_ascend_config()
+        use_spec_decode = len(
+            scheduler_output.scheduled_spec_decode_tokens) > 0
         if np.array_equal(self.seq_lens_np[:num_reqs], num_scheduled_tokens):
             attn_state = AscendAttentionState.PrefillNoCache
         # We assume it is the decode stage, where prefill occurs but only one token is not hit in cache.
         elif np.all(num_scheduled_tokens == 1):
             attn_state = AscendAttentionState.DecodeOnly
+        # Speculative decoding.
+        elif np.all(num_valid_tokens == 1):
+            attn_state = AscendAttentionState.SpecDecoding
         # splitfuse
         elif not ascend_config.ascend_scheduler_config.enabled or self.chunked_prefill_enabled:
             attn_state = AscendAttentionState.ChunkedPrefill
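
With num_valid_tokens available, the dispatch above can tell a speculative-decoding step (one real token per request plus drafts) apart from plain decode and from chunked prefill. A condensed, free-standing sketch of that decision, using strings in place of AscendAttentionState members and dropping the later branches of the elif chain:

    import numpy as np

    def pick_attn_state(seq_lens, num_scheduled, num_valid, chunked_prefill):
        if np.array_equal(seq_lens, num_scheduled):
            return "PrefillNoCache"   # every scheduled token is new context
        if np.all(num_scheduled == 1):
            return "DecodeOnly"       # exactly one token per request, no drafts
        if np.all(num_valid == 1):
            return "SpecDecoding"     # the extra tokens are all draft verifications
        if chunked_prefill:
            return "ChunkedPrefill"   # split-fuse / chunked prefill batch
        return "Prefill"              # remaining prefill cases (elided here)
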
@@ -873,7 +886,9 @@ class NPUModelRunner(LoRAModelRunnerMixin):
         seq_lens = self.seq_lens[:num_reqs]
         common_attn_metadata = CommonAttentionMetadata(
             query_start_loc=query_start_loc, seq_lens=seq_lens)
-        with_prefill = attn_state != AscendAttentionState.DecodeOnly
+        with_prefill = attn_state not in [
+            AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding
+        ]
         if self.dp_size > 1:
             max_num_tokens, with_prefill = self._get_forward_metadata_across_dp(
@@ -883,14 +898,14 @@ class NPUModelRunner(LoRAModelRunnerMixin):
         # Add graph_pad_size here
         if envs_ascend.VLLM_ENABLE_MC2 or (self.torchair_graph_enabled
                                            and not with_prefill):
-            batch_size = len(seq_lens)
             if self.dp_size > 1:
                 padded_batch_size = self.select_torchair_padded_batch_size(
                     max_num_tokens)
             else:
                 padded_batch_size = self.select_torchair_padded_batch_size(
-                    batch_size)
-            graph_pad_size = padded_batch_size - batch_size
+                    total_num_scheduled_tokens)
+            graph_pad_size = padded_batch_size - total_num_scheduled_tokens
             extra_builder_kwargs['graph_pad_size'] = graph_pad_size
         if self.vllm_config.model_config.use_mla:
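
TorchAir graph mode replays precompiled graphs, so a decode batch has to be padded up to one of the captured sizes. This hunk switches the padding reference from the request count (len(seq_lens)) to the scheduled token count, which is what actually determines the tensor shapes once each request can carry extra speculative tokens. A toy sketch of the padding arithmetic (the candidate sizes are invented; the real runner derives them from its graph-capture configuration):

    def select_torchair_padded_batch_size(num_tokens, captured_sizes=(8, 16, 32, 64)):
        for size in captured_sizes:
            if num_tokens <= size:
                return size
        return captured_sizes[-1]

    total_num_scheduled_tokens = 13   # e.g. 13 decode requests, or fewer requests with drafts
    padded = select_torchair_padded_batch_size(total_num_scheduled_tokens)   # -> 16
    graph_pad_size = padded - total_num_scheduled_tokens                     # -> 3
    # graph_pad_size dummy slots are appended so the batch matches a captured graph shape.
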


@@ -4,7 +4,8 @@ from vllm.config import (VllmConfig, get_layers_from_vllm_config,
                          set_current_vllm_config)
 from vllm.forward_context import set_forward_context
 from vllm.model_executor.model_loader import get_model_loader
-from vllm.model_executor.model_loader.utils import set_default_torch_dtype
+from vllm.model_executor.model_loader.utils import (
+    process_weights_after_loading, set_default_torch_dtype)
 from vllm.v1.sample.metadata import SamplingMetadata
 from vllm_ascend.attention.mla_v1 import CommonAttentionMetadata
@@ -199,6 +200,8 @@ class MtpProposer:
             loader.get_all_weights(
                 self.vllm_config.speculative_config.draft_model_config,
                 self.model))
+        process_weights_after_loading(self.model, draft_model_config,
+                                      target_device)
         # TODO Using torch instead of triton may result in poor performance
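
The TODO above concerns the drafter's input preparation: upstream vLLM builds comparable flattened token indices with a Triton kernel, while this NPU path falls back to plain torch ops. A rough sketch of that kind of gather, with all names hypothetical and not taken from this diff:

    import torch

    def build_token_indices(cu_query_lens: torch.Tensor,
                            num_rejected: torch.Tensor) -> torch.Tensor:
        # Per-request number of tokens the drafter should actually consume:
        # scheduled query length minus the draft tokens the verifier rejected.
        query_lens = cu_query_lens[1:] - cu_query_lens[:-1]
        kept = query_lens - num_rejected
        # Flattened indices into the target model's token buffer, one run of
        # consecutive positions per request; repeat_interleave + arange stands
        # in for what a fused Triton kernel would do in a single launch.
        starts = torch.repeat_interleave(cu_query_lens[:-1], kept)
        offsets = torch.cat(
            [torch.arange(n, device=kept.device) for n in kept.tolist()])
        return starts + offsets

    # Example: two requests with query lengths 3 and 4, one rejected draft in req0.
    # build_token_indices(torch.tensor([0, 3, 7]), torch.tensor([1, 0]))
    # -> tensor([0, 1, 3, 4, 5, 6])
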