[Feat] Make full graph mode compatible with MTP (#3276)

### What this PR does / why we need it?
Enable Full Graph mode to run with MTP (Multi-Token Prediction) speculative decoding.
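For illustration, a minimal sketch of how this combination might be exercised from the vLLM API. The model name, token counts, and config values are assumptions for the example, not part of this PR:

```python
# Hypothetical usage sketch: pair full decode-only graph capture with
# DeepSeek MTP speculative decoding. All concrete values are illustrative.
from vllm import LLM, SamplingParams

llm = LLM(
    model="deepseek-ai/DeepSeek-V3",   # placeholder MTP-capable model
    speculative_config={
        "method": "deepseek_mtp",       # the method this PR special-cases
        "num_speculative_tokens": 1,
    },
    compilation_config={"cudagraph_mode": "FULL_DECODE_ONLY"},
)
outputs = llm.generate(["Hello, world"], SamplingParams(max_tokens=16))
```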

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0

Signed-off-by: anon189Ty <Stari_Falcon@outlook.com>
Author: anon189Ty (committed by GitHub)
Date: 2025-10-17 20:19:56 +08:00
Parent: 46e62efd44
Commit: 248ee7fa11

7 changed files with 103 additions and 44 deletions


@@ -345,6 +345,9 @@ class NPUModelRunner(LoRAModelRunnerMixin):
                 self.speculative_config.method, self.vllm_config,
                 self.device, self)
             self.rejection_sampler = AscendRejectionSampler()
+        self.actual_seq_lengths_q = list(
+            range(self.decode_token_per_req, self.max_num_tokens + 1,
+                  self.decode_token_per_req))

         # Persistent batch.
         self.input_ids = torch.zeros(self.max_num_tokens,
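The new `actual_seq_lengths_q` is just the cumulative query-token offsets for a uniform speculative-decode batch. A standalone sketch with assumed values:

```python
# Assumed values: MTP drafting 1 token means each decode request carries
# 2 query tokens (decode_token_per_req = 1 + num_speculative_tokens).
decode_token_per_req = 2
max_num_tokens = 8

actual_seq_lengths_q = list(
    range(decode_token_per_req, max_num_tokens + 1, decode_token_per_req))
print(actual_seq_lengths_q)  # [2, 4, 6, 8]
```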
@@ -366,13 +369,15 @@ class NPUModelRunner(LoRAModelRunnerMixin):
         if self.vllm_config.model_config.use_mla and \
             self.compilation_config.cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY:
             rope_dim = self.model_config.hf_text_config.qk_rope_head_dim
-            self.cos = torch.ones(self.max_num_reqs,
+            self.cos = torch.ones(self.max_num_reqs *
+                                  self.decode_token_per_req,
                                   1,
                                   1,
                                   rope_dim,
                                   dtype=self.dtype,
                                   device=self.device)
-            self.sin = torch.zeros(self.max_num_reqs,
+            self.sin = torch.zeros(self.max_num_reqs *
+                                   self.decode_token_per_req,
                                    1,
                                    1,
                                    rope_dim,
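The first dimension of the RoPE `cos`/`sin` caches grows from one row per request to one row per decode token, since each request now contributes `decode_token_per_req` query tokens. A standalone sketch with assumed sizes:

```python
import torch

# Assumed sizes: 16 requests, 2 decode tokens per request under MTP,
# and a qk_rope_head_dim of 64.
max_num_reqs, decode_token_per_req, rope_dim = 16, 2, 64

cos = torch.ones(max_num_reqs * decode_token_per_req, 1, 1, rope_dim)
sin = torch.zeros(max_num_reqs * decode_token_per_req, 1, 1, rope_dim)
print(cos.shape)  # torch.Size([32, 1, 1, 64])
```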
@@ -1554,7 +1559,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
             if self.vllm_config.model_config.use_mla:
                 # FIXME: Try using `auto_dispatch_capture=True`
                 update_mla_attn_params(self.update_stream, forward_context,
-                                       maybe_padded_num_tokens)
+                                       maybe_padded_num_tokens,
+                                       self.speculative_config)
             else:
                 update_attn_params(self.update_stream, forward_context,
                                    maybe_padded_num_tokens)
@@ -2255,7 +2261,10 @@ class NPUModelRunner(LoRAModelRunnerMixin):
             block_table_tensor = self.input_batch.block_table[
                 kv_cache_group_id].get_device_tensor()
             common_attn_metadata = AscendCommonAttentionMetadata(
-                query_start_loc=self.query_start_loc[:num_reqs + 1],
+                query_start_loc=torch.tensor(
+                    [0] + self.actual_seq_lengths_q[:num_reqs],
+                    device=self.device,
+                    dtype=torch.int32),
                 query_start_loc_cpu=self.query_start_loc_cpu[:num_reqs +
                                                              1],
                 seq_lens_cpu=self.seq_lens_cpu,
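During graph capture, `query_start_loc` is now rebuilt from `actual_seq_lengths_q` so the offsets reflect multi-token decode steps. A CPU-only sketch of the resulting tensor (the real code places it on the NPU device):

```python
import torch

# Assumed values: 3 requests, 2 decode tokens each (see earlier sketch).
actual_seq_lengths_q = [2, 4, 6, 8]
num_reqs = 3

query_start_loc = torch.tensor([0] + actual_seq_lengths_q[:num_reqs],
                               dtype=torch.int32)
print(query_start_loc)  # tensor([0, 2, 4, 6], dtype=torch.int32)
```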
@@ -2275,12 +2284,15 @@ class NPUModelRunner(LoRAModelRunnerMixin):
                 cos=self.cos,
                 sin=self.sin,
             )
+            attn_state = AscendAttentionState.DecodeOnly
+            if self.speculative_config and \
+                self.speculative_config.method == "deepseek_mtp":
+                attn_state = AscendAttentionState.SpecDecoding
             for attn_group in self.attn_groups[kv_cache_group_id]:
                 builder = attn_group.get_metadata_builder()
                 attn_metadata_i = builder.build_for_graph_capture(
-                    common_attn_metadata, AscendAttentionState.DecodeOnly,
-                    self.get_model())
+                    common_attn_metadata, attn_state, self.get_model())
                 for layer_name in kv_cache_group_spec.layer_names:
                     attn_metadata[layer_name] = attn_metadata_i
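Schematically, the capture path now selects the attention state from the speculative method. The stub types below stand in for the real Ascend enum and config objects, purely for illustration:

```python
from enum import Enum, auto

class AscendAttentionState(Enum):  # stub for the real Ascend enum
    DecodeOnly = auto()
    SpecDecoding = auto()

class _SpecConfig:  # stub: only the field the check reads
    method = "deepseek_mtp"

speculative_config = _SpecConfig()  # None when spec decode is disabled

attn_state = AscendAttentionState.DecodeOnly
if speculative_config and speculative_config.method == "deepseek_mtp":
    attn_state = AscendAttentionState.SpecDecoding
print(attn_state)  # AscendAttentionState.SpecDecoding
```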
@@ -2301,7 +2313,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
             if self.vllm_config.model_config.use_mla:
                 # FIXME: Try using `auto_dispatch_capture=True`
                 update_mla_attn_params(self.update_stream, forward_context,
-                                       positions.shape[0])
+                                       positions.shape[0],
+                                       self.speculative_config)
             else:
                 update_attn_params(self.update_stream, forward_context,
                                    positions.shape[0])
@@ -3388,23 +3401,6 @@ class NPUModelRunner(LoRAModelRunnerMixin):
                     CUDAGraphMode.FULL_DECODE_ONLY)
             logger.warning(msg)
-        # check that if spec-decode + decode full-graphs is supported
-        if (aclgraph_mode.decode_mode() == CUDAGraphMode.FULL
-                and self.uniform_decode_query_len > 1 and min_ag_support.value
-                < AttentionCGSupport.UNIFORM_BATCH.value):
-            msg = (f"CUDAGraphMode.{aclgraph_mode.name} is not supported"
-                   f" with spec-decode for attention backend "
-                   f"{min_ag_builder_name} (support: {min_ag_support})")
-            if splitting_ops_contain_attention:
-                msg += "; setting cudagraph_mode=PIECEWISE"
-                aclgraph_mode = self.compilation_config.cudagraph_mode = \
-                    CUDAGraphMode.PIECEWISE
-            else:
-                msg += "; setting cudagraph_mode=NONE"
-                aclgraph_mode = self.compilation_config.cudagraph_mode = \
-                    CUDAGraphMode.NONE
-            logger.warning(msg)
         # double check that we can support full graph if they are requested
         # even after automatic downgrades
         if aclgraph_mode.has_full_cudagraphs() \