[FEAT] Support DeepSeek-V3.2 with FULL_DECODE_ONLY mode (#4706)
### What this PR does / why we need it?
The first commit support `FULL_DECODE_ONLY`:
- Update `AscendSFAMetadataBuilder` to use `num_input_tokens` for
slicing slots and positions, ensuring fixed tensor shapes.
- Implement padding logic for `query_start_loc` in `NPUModelRunner` to
support uniform decode in full graph mode, aligning with GPU runner
behavior.
- Adjust MLA cosine cache allocation to occur independently of graph
mode and switch to using device-resident sequence lengths for attention
metadata.
- Remove redundant slicing of hidden states and outputs in
`AscendSFAImpl` and optimize `sin`/`cos` cache updates.
The second commit take MTP into account:
- Update `AscendSFAMetadataBuilder` to use `num_input_tokens` for
slicing slots and positions, ensuring fixed tensor shapes.
- Implement padding logic for `query_start_loc` in `NPUModelRunner` to
support uniform decode in full graph mode, aligning with GPU runner
behavior.
- Adjust MLA cosine cache allocation to occur independently of graph
mode and switch to using device-resident sequence lengths for attention
metadata.
- Remove redundant slicing of hidden states and outputs in
`AscendSFAImpl` and optimize `sin`/`cos` cache updates.
And the rest of them are just bugfix.
### Does this PR introduce _any_ user-facing change?
None.
### How was this patch tested?
Test cases needed.
- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c
---------
Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com>
This commit is contained in:
@@ -273,6 +273,9 @@ def update_mla_attn_params(update_stream, forward_context, runtime_shape,
|
||||
key].decode.actual_seq_lengths_q
|
||||
block_table = forward_context.attn_metadata[
|
||||
key].decode.block_table
|
||||
# TODO: This is a hack and should be fixed in the future.
|
||||
if speculative_config.disable_padded_drafter_batch:
|
||||
block_table = block_table[:len(actual_seq_lengths)]
|
||||
seq_lens_list = seq_lens_list + [0] * (
|
||||
len(actual_seq_lengths) - len(seq_lens_list))
|
||||
else:
|
||||
@@ -427,7 +430,7 @@ class GraphParams:
|
||||
_graph_params: Optional[GraphParams] = None
|
||||
|
||||
|
||||
def set_graph_params(aclgraph_capture_sizes: set[int]):
|
||||
def set_graph_params(aclgraph_capture_sizes: list[int]):
|
||||
global _graph_params
|
||||
if _graph_params is not None:
|
||||
raise ValueError("Graph parameters have already been set!")
|
||||
@@ -456,7 +459,7 @@ def get_graph_params():
|
||||
_mtp_graph_params: Optional[GraphParams] = None
|
||||
|
||||
|
||||
def set_mtp_graph_params(aclgraph_capture_sizes: set[int]):
|
||||
def set_mtp_graph_params(aclgraph_capture_sizes: list[int]):
|
||||
global _mtp_graph_params
|
||||
if _mtp_graph_params is not None:
|
||||
raise ValueError("MTPGraph parameters have already been set!")
|
||||
|
||||
Reference in New Issue
Block a user