[FEAT] Support DeepSeek-V3.2 with FULL_DECODE_ONLY mode (#4706)
### What this PR does / why we need it?
The first commit supports `FULL_DECODE_ONLY`:
- Update `AscendSFAMetadataBuilder` to use `num_input_tokens` for
slicing slots and positions, ensuring fixed tensor shapes.
- Implement padding logic for `query_start_loc` in `NPUModelRunner` to
support uniform decode in full graph mode, aligning with the GPU runner
behavior (see the sketch after this list).
- Adjust MLA cosine cache allocation to occur independently of graph
mode and switch to using device-resident sequence lengths for attention
metadata.
- Remove redundant slicing of hidden states and outputs in
`AscendSFAImpl` and optimize `sin`/`cos` cache updates.
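
To make the fixed-shape bookkeeping above concrete, here is a minimal sketch assuming a uniform-decode batch (every request contributes the same number of query tokens per step). It is not the actual `NPUModelRunner` code; the function and parameter names (`build_fixed_shape_inputs`, `pad_to`, `max_num_reqs`) are illustrative only, while `query_start_loc` and `num_input_tokens` follow the names used in this PR:

```python
import torch

def build_fixed_shape_inputs(
    query_lens: torch.Tensor,      # [num_reqs] query tokens per request
    positions: torch.Tensor,       # [num_actual_tokens] token positions
    slot_mapping: torch.Tensor,    # [num_actual_tokens] KV-cache slots
    max_num_reqs: int,             # capture-time request capacity (num_reqs <= max_num_reqs)
    num_input_tokens: int,         # padded token count chosen for the captured graph
):
    num_reqs = query_lens.shape[0]

    # query_start_loc is the prefix sum [0, q1, q1+q2, ...]; pad it by repeating
    # the last prefix sum so its length is always max_num_reqs + 1.
    query_start_loc = torch.zeros(max_num_reqs + 1, dtype=torch.int32)
    query_start_loc[1:num_reqs + 1] = torch.cumsum(query_lens, dim=0)
    query_start_loc[num_reqs + 1:] = query_start_loc[num_reqs]

    # Slice/pad positions and slot_mapping to the fixed num_input_tokens so the
    # tensors handed to the captured graph never change shape.
    def pad_to(t: torch.Tensor, length: int) -> torch.Tensor:
        out = torch.zeros(length, dtype=t.dtype)
        out[:min(t.shape[0], length)] = t[:length]
        return out

    return (query_start_loc,
            pad_to(positions, num_input_tokens),
            pad_to(slot_mapping, num_input_tokens))
```

With a padded `query_start_loc` and fixed-size `positions`/`slot_mapping`, one captured decode graph can be replayed for any batch up to `max_num_reqs` requests without retracing, which is what "aligning with GPU runner behavior" refers to above.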
The second commit takes MTP into account, applying the same set of changes to the MTP proposer (`MtpProposer`) path.
The remaining commits are bug fixes.
### Does this PR introduce _any_ user-facing change?
None.
### How was this patch tested?
Test cases are still needed.
- vLLM version: v0.12.0
- vLLM main: ad32e3e19c
---------
Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com>
@@ -32,7 +32,6 @@ from vllm_ascend.ascend_forward_context import (MoECommType,
from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
from vllm_ascend.compilation.acl_graph import (ACLGraphWrapper,
set_mtp_graph_params,
update_mla_attn_params)
from vllm_ascend.spec_decode.interface import Proposer, SpecDcodeType
from vllm_ascend.utils import (ProfileExecuteDuration, lmhead_tp_enable,

@@ -214,8 +213,6 @@ class MtpProposer(Proposer):
if self.vllm_config.compilation_config.cudagraph_mode.has_full_cudagraphs(
):
self.update_stream: torch.npu.Stream = torch.npu.Stream()
set_mtp_graph_params(
self.vllm_config.compilation_config.cudagraph_capture_sizes)
self.model = ACLGraphWrapper(self.model,
self.vllm_config,
runtime_mode=CUDAGraphMode.FULL)

@@ -254,9 +251,10 @@ class MtpProposer(Proposer):
query_start_loc_cpu=self.runner.
query_start_loc_cpu[:num_reqs + 1],
seq_lens_cpu=self.runner.seq_lens_cpu,
seq_lens=self.runner.seq_lens_cpu[:num_reqs],
seq_lens=self.runner.seq_lens[:num_reqs],
num_reqs=num_reqs,
num_actual_tokens=num_tokens,
num_input_tokens=num_tokens,
max_query_len=self.num_speculative_tokens + 1,
num_computed_tokens_cpu=num_computed_tokens_cpu,
actual_seq_lengths_q=self.runner.actual_seq_lengths_q,

@@ -289,7 +287,7 @@ class MtpProposer(Proposer):
positions = self.positions[:num_tokens]
previous_hidden_states = self.hidden_states[:num_tokens]
for i in range(self.num_speculative_tokens):
if i > 0:
if i > 0 and not skip_attn and aclgraph_runtime_mode == CUDAGraphMode.FULL:
aclgraph_runtime_mode = CUDAGraphMode.NONE
with set_ascend_forward_context(
attn_metadata,

@@ -316,7 +314,7 @@ class MtpProposer(Proposer):
forward_context = get_forward_context()
if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL and \
not forward_context.capturing:
if self.vllm_config.model_config.use_mla:
if self.vllm_config.model_config.use_mla and not self.use_sparse:
update_mla_attn_params(
self.update_stream, forward_context, num_tokens,
self.vllm_config.speculative_config)

@@ -514,6 +512,7 @@ class MtpProposer(Proposer):
# q1, q1 + 1, ..., q1 + q2 - n2 - 1,
# q1 + q2, q1 + q2 + 1, ..., q1 + q2 + q3 - n3 - 1]

num_actual_reqs = len(num_draft_tokens)
num_rejected_tokens = [
n + 1 - len(sampled_token_ids[i]) if n > 0 else 0
for i, n in enumerate(num_draft_tokens)

@@ -522,8 +521,11 @@ class MtpProposer(Proposer):
dtype=torch.int32)

device = common_attn_metadata.query_start_loc.device
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
new_seq_lens_cpu = common_attn_metadata.seq_lens_cpu - num_rejected_tokens
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu[:
num_actual_reqs
+ 1]
seq_lens_cpu = common_attn_metadata.seq_lens_cpu[:num_actual_reqs]
new_seq_lens_cpu = seq_lens_cpu - num_rejected_tokens

# [0, q1, q1 + q2, q1 + q2 + q3] -> [q1, q2, q3]
new_query_len_per_req = query_start_loc_cpu[

@@ -587,6 +589,7 @@ class MtpProposer(Proposer):
num_computed_tokens_cpu,
num_reqs=common_attn_metadata.num_reqs,
num_actual_tokens=total_num_tokens,
num_input_tokens=common_attn_metadata.num_input_tokens,
max_query_len=new_query_len_per_req.max().item(),
block_table_tensor=common_attn_metadata.block_table_tensor,
slot_mapping=common_attn_metadata.slot_mapping,

@@ -704,8 +707,8 @@ class MtpProposer(Proposer):

assert self.runner is not None

if self.vllm_config.compilation_config.cudagraph_mode.has_full_cudagraphs(
) and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1]:
if self.runner.use_aclgraph and num_scheduled_tokens <= self.cudagraph_batch_sizes[
-1]:
num_input_tokens = self.vllm_config.pad_for_cudagraph(
num_scheduled_tokens)
elif self.use_aclgraph and num_tokens <= self.cudagraph_batch_sizes[-1]:

@@ -797,7 +800,7 @@ class MtpProposer(Proposer):
hidden_states=hidden_states)
forward_context = get_forward_context()
if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL:
if self.vllm_config.model_config.use_mla:
if self.vllm_config.model_config.use_mla and not self.use_sparse:
update_mla_attn_params(
self.update_stream, forward_context,
num_input_tokens,

@@ -1109,9 +1112,10 @@ class MtpProposer(Proposer):
spec_common_attn_metadata = AscendCommonAttentionMetadata(
query_start_loc=common_attn_metadata.query_start_loc,
query_start_loc_cpu=query_start_loc_cpu,
seq_lens_cpu=common_attn_metadata.seq_lens,
seq_lens_cpu=common_attn_metadata.seq_lens_cpu,
num_reqs=common_attn_metadata.num_reqs,
num_actual_tokens=total_num_tokens,
num_input_tokens=common_attn_metadata.num_input_tokens,
max_query_len=new_query_len_per_req.max().item(),
actual_seq_lengths_q=self.runner.actual_seq_lengths_q,
block_table_tensor=common_attn_metadata.block_table_tensor,
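
As a quick reference for the rejected-token bookkeeping touched by the `@@ -514`/`@@ -522` hunks above, here is a standalone toy example with made-up numbers; it is not the `MtpProposer` code itself, only the arithmetic it performs:

```python
import torch

num_draft_tokens = [3, 0, 3]                      # drafts proposed per request
sampled_token_ids = [[7, 8], [5], [1, 2, 3, 4]]   # accepted tokens (+ bonus token)

# A request that proposed n drafts keeps len(sampled) - 1 of them, so it
# rejects n + 1 - len(sampled); requests with no drafts reject nothing.
num_rejected_tokens = [
    n + 1 - len(sampled_token_ids[i]) if n > 0 else 0
    for i, n in enumerate(num_draft_tokens)
]
num_rejected_tokens = torch.tensor(num_rejected_tokens, dtype=torch.int32)

# Truncate the (possibly padded) metadata to the actual request count, then
# drop the rejected tokens from every sequence length.
num_actual_reqs = len(num_draft_tokens)
query_start_loc_cpu = torch.tensor([0, 4, 5, 9], dtype=torch.int32)[:num_actual_reqs + 1]
seq_lens_cpu = torch.tensor([14, 9, 21], dtype=torch.int32)[:num_actual_reqs]
new_seq_lens_cpu = seq_lens_cpu - num_rejected_tokens

# [0, q1, q1 + q2, q1 + q2 + q3] -> [q1, q2, q3]
new_query_len_per_req = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]

print(num_rejected_tokens.tolist())     # [2, 0, 0]
print(new_seq_lens_cpu.tolist())        # [12, 9, 21]
print(new_query_len_per_req.tolist())   # [4, 1, 4]
```

Each request keeps its accepted drafts plus the bonus token; everything else is subtracted from its sequence length before the trimmed metadata is rebuilt for the next step.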