[FEAT] Support DeepSeek-V3.2 with FULL_DECODE_ONLY mode (#4706)

### What this PR does / why we need it?
The first commit support `FULL_DECODE_ONLY`:
- Update `AscendSFAMetadataBuilder` to use `num_input_tokens` for
slicing slots and positions, ensuring fixed tensor shapes.
- Implement padding logic for `query_start_loc` in `NPUModelRunner` to
support uniform decode in full graph mode, aligning with GPU runner
behavior.
- Adjust MLA cosine cache allocation to occur independently of graph
mode and switch to using device-resident sequence lengths for attention
metadata.
- Remove redundant slicing of hidden states and outputs in
`AscendSFAImpl` and optimize `sin`/`cos` cache updates.

The second commit take MTP into account:
- Update `AscendSFAMetadataBuilder` to use `num_input_tokens` for
slicing slots and positions, ensuring fixed tensor shapes.
- Implement padding logic for `query_start_loc` in `NPUModelRunner` to
support uniform decode in full graph mode, aligning with GPU runner
behavior.
- Adjust MLA cosine cache allocation to occur independently of graph
mode and switch to using device-resident sequence lengths for attention
metadata.
- Remove redundant slicing of hidden states and outputs in
`AscendSFAImpl` and optimize `sin`/`cos` cache updates.

And the rest of them are just bugfix.

### Does this PR introduce _any_ user-facing change?
None.

### How was this patch tested?
Test cases needed.


- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c

---------

Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com>
This commit is contained in:
Yizhou
2025-12-10 20:11:09 +08:00
committed by GitHub
parent 0d8c0f1a24
commit 5b179c53f1
6 changed files with 120 additions and 78 deletions

View File

@@ -273,6 +273,9 @@ def update_mla_attn_params(update_stream, forward_context, runtime_shape,
key].decode.actual_seq_lengths_q
block_table = forward_context.attn_metadata[
key].decode.block_table
# TODO: This is a hack and should be fixed in the future.
if speculative_config.disable_padded_drafter_batch:
block_table = block_table[:len(actual_seq_lengths)]
seq_lens_list = seq_lens_list + [0] * (
len(actual_seq_lengths) - len(seq_lens_list))
else:
@@ -427,7 +430,7 @@ class GraphParams:
_graph_params: Optional[GraphParams] = None
def set_graph_params(aclgraph_capture_sizes: set[int]):
def set_graph_params(aclgraph_capture_sizes: list[int]):
global _graph_params
if _graph_params is not None:
raise ValueError("Graph parameters have already been set!")
@@ -456,7 +459,7 @@ def get_graph_params():
_mtp_graph_params: Optional[GraphParams] = None
def set_mtp_graph_params(aclgraph_capture_sizes: set[int]):
def set_mtp_graph_params(aclgraph_capture_sizes: list[int]):
global _mtp_graph_params
if _mtp_graph_params is not None:
raise ValueError("MTPGraph parameters have already been set!")