[FEAT] Support DeepSeek-V3.2 with FULL_DECODE_ONLY mode (#4706)
### What this PR does / why we need it?
The first commit adds support for `FULL_DECODE_ONLY`:
- Update `AscendSFAMetadataBuilder` to use `num_input_tokens` for
slicing slots and positions, ensuring fixed tensor shapes.
- Implement padding logic for `query_start_loc` in `NPUModelRunner` to
support uniform decode in full graph mode, aligning with GPU runner
behavior (see the sketch after this list).
- Adjust MLA cosine cache allocation to occur independently of graph
mode and switch to using device-resident sequence lengths for attention
metadata.
- Remove redundant slicing of hidden states and outputs in
`AscendSFAImpl` and optimize `sin`/`cos` cache updates.
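For context, here is a minimal runnable sketch of the padding scheme: one dummy "request" is appended so the metadata tensors keep the shape the captured graph expects. The `pad_for_full_graph` wrapper and the concrete numbers are illustrative, not code from this PR; the tensor names mirror the diff excerpt below.

```python
import torch

def pad_for_full_graph(query_start_loc_cpu: torch.Tensor,
                       seq_lens: torch.Tensor,
                       num_actual_tokens: int,
                       num_input_tokens: int):
    """Append one dummy request so tensor shapes match the captured graph."""
    if num_input_tokens > num_actual_tokens:
        padded_num_tokens = num_input_tokens - num_actual_tokens
        # The dummy request's sequence length covers all padded token slots.
        seq_lens = torch.cat([
            seq_lens,
            torch.full((1, ), padded_num_tokens,
                       dtype=seq_lens.dtype, device=seq_lens.device)
        ])
        # Extend the cumulative token offsets past the real tokens.
        query_start_loc_cpu = torch.cat([
            query_start_loc_cpu,
            query_start_loc_cpu[-1:] + padded_num_tokens
        ])
    return query_start_loc_cpu, seq_lens

# Three single-token decode requests padded up to a graph size of 8 tokens.
qsl = torch.tensor([0, 1, 2, 3], dtype=torch.int32)
lens = torch.tensor([5, 9, 7], dtype=torch.int32)
qsl, lens = pad_for_full_graph(qsl, lens, num_actual_tokens=3, num_input_tokens=8)
print(qsl)   # tensor([0, 1, 2, 3, 8], dtype=torch.int32)
print(lens)  # tensor([5, 9, 7, 5], dtype=torch.int32)
```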
The second commit takes MTP (multi-token prediction) into account, extending the same changes to the MTP path.
The remaining commits are bugfixes.
### Does this PR introduce _any_ user-facing change?
None.
### How was this patch tested?
Test cases needed.
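As a starting point, here is a hedged, hypothetical pytest-style sketch (not part of this PR) exercising the block-table zero-padding shown in the diff below:

```python
import torch

def test_block_table_padding_keeps_shape():
    # Mirrors the zero-padding of block_table in the diff excerpt below.
    block_table = torch.randint(0, 100, (3, 16), dtype=torch.int32)
    padded_num_tokens = 5
    padding = torch.zeros((padded_num_tokens, ) + block_table.shape[1:],
                          dtype=block_table.dtype,
                          device=block_table.device)
    padded = torch.cat([block_table, padding], dim=0)
    assert padded.shape == (3 + padded_num_tokens, 16)
    assert (padded[3:] == 0).all()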
- vLLM version: v0.12.0
- vLLM main: ad32e3e19c
---------
Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com>
@@ -297,30 +297,11 @@ class AscendAttentionMetadataBuilder:
```python
slot_mapping = common_attn_metadata.slot_mapping[:num_actual_tokens_pcp_padded]
# slot_mapping = common_attn_metadata.slot_mapping[:num_actual_tokens]

attn_mask = common_attn_metadata.attn_mask
attn_state = common_attn_metadata.attn_state
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu[:num_reqs + 1]
if common_attn_metadata.num_input_tokens > num_actual_tokens:
    padded_num_tokens = common_attn_metadata.num_input_tokens - num_actual_tokens
    seq_lens = torch.cat([
        seq_lens,
        torch.tensor([padded_num_tokens]).to(seq_lens.device).to(seq_lens.dtype)
    ])
    block_table_padding = torch.zeros(
        (padded_num_tokens, ) + block_table.shape[1:],
        dtype=block_table.dtype,
        device=block_table.device)
    block_table = torch.cat([block_table, block_table_padding], dim=0)
    query_start_loc_cpu = torch.cat([
        query_start_loc_cpu,
        torch.tensor([query_start_loc_cpu[-1] + padded_num_tokens]).to(
            query_start_loc_cpu.device).to(query_start_loc_cpu.dtype)
    ])

# TODO: Yet another unnecessary H2D while we already have a query_start_loc on device
query_start_loc = query_start_loc_cpu.pin_memory().to(
    self.device, non_blocking=True)
is_causal_pooling = None
```
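On the `pin_memory()` + `non_blocking=True` pattern flagged by the TODO above: a non-blocking host-to-device copy only overlaps with other device work when the source tensor sits in pinned host memory. A minimal sketch of the pattern (CUDA is used for illustration; the PR itself targets NPU devices via `self.device`):

```python
import torch

query_start_loc_cpu = torch.tensor([0, 1, 2, 3, 8], dtype=torch.int32)
if torch.cuda.is_available():
    # Staging through pinned memory lets the copy run asynchronously.
    query_start_loc = query_start_loc_cpu.pin_memory().to(
        "cuda", non_blocking=True)
else:
    query_start_loc = query_start_loc_cpu  # no device to copy to
```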