[FEAT] Support DeepSeek-V3.2 with FULL_DECODE_ONLY mode (#4706)

### What this PR does / why we need it?
The first commit supports `FULL_DECODE_ONLY`:
- Update `AscendSFAMetadataBuilder` to use `num_input_tokens` for
slicing slots and positions, ensuring fixed tensor shapes.
- Implement padding logic for `query_start_loc` in `NPUModelRunner` to
support uniform decode in full graph mode, aligning with GPU runner
behavior (see the sketch after this list).
- Adjust MLA cosine cache allocation to occur independently of graph
mode and switch to using device-resident sequence lengths for attention
metadata.
- Remove redundant slicing of hidden states and outputs in
`AscendSFAImpl` and optimize `sin`/`cos` cache updates.
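
As a rough illustration of the padding idea (a minimal sketch, not the
actual `NPUModelRunner` code; the helper name and `max_num_reqs`
parameter are hypothetical):

```python
import torch

def pad_query_start_loc(query_start_loc: torch.Tensor, num_reqs: int,
                        max_num_reqs: int) -> torch.Tensor:
    """Illustrative sketch: pad cumulative query offsets to a fixed length.

    query_start_loc holds [0, q1, q1 + q2, ...] for the real requests.
    The padded tail repeats the last offset, so the dummy requests have
    zero-length queries and the tensor shape stays constant across
    full-graph captures and replays.
    """
    padded = torch.empty(max_num_reqs + 1,
                         dtype=query_start_loc.dtype,
                         device=query_start_loc.device)
    padded[:num_reqs + 1] = query_start_loc[:num_reqs + 1]
    padded[num_reqs + 1:] = query_start_loc[num_reqs]  # repeat last offset
    return padded
```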

The second commit takes MTP into account, applying the analogous
changes to `MtpProposer`: it passes `num_input_tokens` through the
attention metadata, switches to device-resident `seq_lens`, restricts
`update_mla_attn_params` to non-sparse MLA, and trims
`query_start_loc_cpu`/`seq_lens_cpu` to the actual number of requests
when accounting for rejected draft tokens (a worked sketch of that
bookkeeping follows).
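
The rejected-token bookkeeping visible in the diff below can be
sketched as follows (illustrative only; the function wrapper is
hypothetical, but the arithmetic mirrors the hunk around
`num_rejected_tokens`):

```python
import torch

def shrink_after_rejection(query_start_loc_cpu: torch.Tensor,
                           seq_lens_cpu: torch.Tensor,
                           num_draft_tokens: list[int],
                           sampled_token_ids: list[list[int]]):
    # Illustrative sketch; not the actual MtpProposer code.
    num_actual_reqs = len(num_draft_tokens)
    # A request that drafted n tokens but kept len(sampled) of them
    # rejected n + 1 - len(sampled) (the +1 is the bonus token);
    # requests with no drafts reject nothing.
    num_rejected_tokens = torch.tensor(
        [n + 1 - len(sampled_token_ids[i]) if n > 0 else 0
         for i, n in enumerate(num_draft_tokens)],
        dtype=torch.int32)
    # Trim the padded CPU tensors to the actual request count.
    qsl = query_start_loc_cpu[:num_actual_reqs + 1]
    new_seq_lens_cpu = seq_lens_cpu[:num_actual_reqs] - num_rejected_tokens
    # [0, q1, q1 + q2, q1 + q2 + q3] -> [q1, q2, q3]
    query_len_per_req = qsl[1:] - qsl[:-1]
    return new_seq_lens_cpu, query_len_per_req
```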

The remaining commits are bugfixes.

### Does this PR introduce _any_ user-facing change?
None.

### How was this patch tested?
Test cases are still needed.


- vLLM version: v0.12.0
- vLLM main: ad32e3e19c

---------

Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com>
Author: Yizhou
Date: 2025-12-10 20:11:09 +08:00 (committed by GitHub)
Parent: 0d8c0f1a24
Commit: 5b179c53f1
6 changed files with 120 additions and 78 deletions


@@ -32,7 +32,6 @@ from vllm_ascend.ascend_forward_context import (MoECommType,
 from vllm_ascend.attention.attention_v1 import AscendAttentionState
 from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
 from vllm_ascend.compilation.acl_graph import (ACLGraphWrapper,
-                                               set_mtp_graph_params,
                                                update_mla_attn_params)
 from vllm_ascend.spec_decode.interface import Proposer, SpecDcodeType
 from vllm_ascend.utils import (ProfileExecuteDuration, lmhead_tp_enable,
@@ -214,8 +213,6 @@ class MtpProposer(Proposer):
         if self.vllm_config.compilation_config.cudagraph_mode.has_full_cudagraphs(
         ):
             self.update_stream: torch.npu.Stream = torch.npu.Stream()
-            set_mtp_graph_params(
-                self.vllm_config.compilation_config.cudagraph_capture_sizes)
             self.model = ACLGraphWrapper(self.model,
                                          self.vllm_config,
                                          runtime_mode=CUDAGraphMode.FULL)
@@ -254,9 +251,10 @@ class MtpProposer(Proposer):
             query_start_loc_cpu=self.runner.
             query_start_loc_cpu[:num_reqs + 1],
             seq_lens_cpu=self.runner.seq_lens_cpu,
-            seq_lens=self.runner.seq_lens_cpu[:num_reqs],
+            seq_lens=self.runner.seq_lens[:num_reqs],
             num_reqs=num_reqs,
             num_actual_tokens=num_tokens,
+            num_input_tokens=num_tokens,
             max_query_len=self.num_speculative_tokens + 1,
             num_computed_tokens_cpu=num_computed_tokens_cpu,
             actual_seq_lengths_q=self.runner.actual_seq_lengths_q,
@@ -289,7 +287,7 @@ class MtpProposer(Proposer):
         positions = self.positions[:num_tokens]
         previous_hidden_states = self.hidden_states[:num_tokens]
         for i in range(self.num_speculative_tokens):
-            if i > 0:
+            if i > 0 and not skip_attn and aclgraph_runtime_mode == CUDAGraphMode.FULL:
                 aclgraph_runtime_mode = CUDAGraphMode.NONE
             with set_ascend_forward_context(
                     attn_metadata,
@@ -316,7 +314,7 @@ class MtpProposer(Proposer):
             forward_context = get_forward_context()
             if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL and \
                     not forward_context.capturing:
-                if self.vllm_config.model_config.use_mla:
+                if self.vllm_config.model_config.use_mla and not self.use_sparse:
                     update_mla_attn_params(
                         self.update_stream, forward_context, num_tokens,
                         self.vllm_config.speculative_config)
@@ -514,6 +512,7 @@ class MtpProposer(Proposer):
         # q1, q1 + 1, ..., q1 + q2 - n2 - 1,
         # q1 + q2, q1 + q2 + 1, ..., q1 + q2 + q3 - n3 - 1]
+        num_actual_reqs = len(num_draft_tokens)
         num_rejected_tokens = [
             n + 1 - len(sampled_token_ids[i]) if n > 0 else 0
             for i, n in enumerate(num_draft_tokens)
@@ -522,8 +521,11 @@ class MtpProposer(Proposer):
                                            dtype=torch.int32)
         device = common_attn_metadata.query_start_loc.device
-        query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
-        new_seq_lens_cpu = common_attn_metadata.seq_lens_cpu - num_rejected_tokens
+        query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu[:
+                                                                       num_actual_reqs
+                                                                       + 1]
+        seq_lens_cpu = common_attn_metadata.seq_lens_cpu[:num_actual_reqs]
+        new_seq_lens_cpu = seq_lens_cpu - num_rejected_tokens
         # [0, q1, q1 + q2, q1 + q2 + q3] -> [q1, q2, q3]
         new_query_len_per_req = query_start_loc_cpu[
@@ -587,6 +589,7 @@ class MtpProposer(Proposer):
             num_computed_tokens_cpu,
             num_reqs=common_attn_metadata.num_reqs,
             num_actual_tokens=total_num_tokens,
+            num_input_tokens=common_attn_metadata.num_input_tokens,
             max_query_len=new_query_len_per_req.max().item(),
             block_table_tensor=common_attn_metadata.block_table_tensor,
             slot_mapping=common_attn_metadata.slot_mapping,
@@ -704,8 +707,8 @@ class MtpProposer(Proposer):
         assert self.runner is not None
-        if self.vllm_config.compilation_config.cudagraph_mode.has_full_cudagraphs(
-        ) and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1]:
+        if self.runner.use_aclgraph and num_scheduled_tokens <= self.cudagraph_batch_sizes[
+                -1]:
             num_input_tokens = self.vllm_config.pad_for_cudagraph(
                 num_scheduled_tokens)
         elif self.use_aclgraph and num_tokens <= self.cudagraph_batch_sizes[-1]:
@@ -797,7 +800,7 @@ class MtpProposer(Proposer):
                 hidden_states=hidden_states)
             forward_context = get_forward_context()
             if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL:
-                if self.vllm_config.model_config.use_mla:
+                if self.vllm_config.model_config.use_mla and not self.use_sparse:
                     update_mla_attn_params(
                         self.update_stream, forward_context,
                         num_input_tokens,
@@ -1109,9 +1112,10 @@ class MtpProposer(Proposer):
         spec_common_attn_metadata = AscendCommonAttentionMetadata(
             query_start_loc=common_attn_metadata.query_start_loc,
             query_start_loc_cpu=query_start_loc_cpu,
-            seq_lens_cpu=common_attn_metadata.seq_lens,
+            seq_lens_cpu=common_attn_metadata.seq_lens_cpu,
             num_reqs=common_attn_metadata.num_reqs,
             num_actual_tokens=total_num_tokens,
+            num_input_tokens=common_attn_metadata.num_input_tokens,
             max_query_len=new_query_len_per_req.max().item(),
             actual_seq_lengths_q=self.runner.actual_seq_lengths_q,
             block_table_tensor=common_attn_metadata.block_table_tensor,