[FEAT] Support DeepSeek-V3.2 with FULL_DECODE_ONLY mode (#4706)
### What this PR does / why we need it?
The first commit supports `FULL_DECODE_ONLY`:
- Update `AscendSFAMetadataBuilder` to use `num_input_tokens` for
slicing slots and positions, ensuring fixed tensor shapes.
- Implement padding logic for `query_start_loc` in `NPUModelRunner` to
support uniform decode in full graph mode, aligning with the GPU runner's
behavior (sketched right after this list).
- Adjust MLA cosine cache allocation to occur independently of graph
mode and switch to using device-resident sequence lengths for attention
metadata.
- Remove redundant slicing of hidden states and outputs in
`AscendSFAImpl` and optimize `sin`/`cos` cache updates.
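
To make the second bullet concrete, here is a minimal sketch of the padding idea. The helper name and standalone signature are illustrative only; the actual change lives inside `NPUModelRunner`, as shown in the diff below.

```python
import torch

def pad_query_start_loc_for_uniform_decode(query_start_loc: torch.Tensor,
                                           num_reqs: int,
                                           num_input_tokens: int,
                                           uniform_decode_query_len: int) -> int:
    """Extend query_start_loc in place so its prefix matches the padded
    request count expected by a captured full graph; returns num_reqs_padded."""
    num_reqs_padded = num_input_tokens // uniform_decode_query_len
    pad_size = num_reqs_padded - num_reqs
    if pad_size > 0:
        last = query_start_loc[num_reqs]
        steps = torch.arange(1, pad_size + 1,
                             device=query_start_loc.device,
                             dtype=query_start_loc.dtype)
        # Each padded (fake) request contributes uniform_decode_query_len tokens,
        # so the cumulative boundaries keep increasing by that fixed stride.
        query_start_loc[num_reqs + 1:num_reqs_padded + 1] = \
            last + steps * uniform_decode_query_len
    return num_reqs_padded

# Example: 3 real decode requests (1 token each), graph captured for 4 tokens.
qsl = torch.zeros(8, dtype=torch.int32)
qsl[1:4] = torch.tensor([1, 2, 3], dtype=torch.int32)
padded = pad_query_start_loc_for_uniform_decode(qsl, num_reqs=3,
                                                num_input_tokens=4,
                                                uniform_decode_query_len=1)
print(padded, qsl[:padded + 1])  # 4  tensor([0, 1, 2, 3, 4], dtype=torch.int32)
```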
The second commit takes MTP (multi-token prediction) into account.
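
On the MTP side, the gist is that graph parameters are registered for the drafter as well once the capture sizes are known. The snippet below mirrors the last hunk of the diff; the wrapper function is illustrative, only the two imported calls come from the change itself.

```python
from vllm_ascend.compilation.acl_graph import set_graph_params, set_mtp_graph_params

def register_graph_params(aclgraph_batch_sizes: list[int],
                          has_speculative_config: bool) -> None:
    """Register ACL graph parameters for the target model and, when MTP
    (speculative decoding) is enabled, for the drafter as well."""
    set_graph_params(aclgraph_batch_sizes)
    if has_speculative_config:
        set_mtp_graph_params(aclgraph_batch_sizes)
```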
The remaining commits are bug fixes.
### Does this PR introduce _any_ user-facing change?
None.
### How was this patch tested?
Test cases needed.
- vLLM version: v0.12.0
- vLLM main: ad32e3e19c
---------
Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com>
```diff
@@ -124,6 +124,7 @@ from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
 # yapf: disable
 from vllm_ascend.compilation.acl_graph import (ACLGraphWrapper,
                                                set_graph_params,
+                                               set_mtp_graph_params,
                                                update_attn_dcp_pcp_params,
                                                update_attn_params,
                                                update_mla_attn_dcp_pcp_params,
@@ -406,8 +407,8 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin):
                                       dtype=torch.int32,
                                       device=self.device)
 
-        if self.vllm_config.model_config.use_mla and \
-            self.compilation_config.cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY:
+        # NOTE: This will have some extra memory allocated, is it OK?
+        if self.vllm_config.model_config.use_mla:
             rope_dim = self.model_config.hf_text_config.qk_rope_head_dim
             self.cos = torch.ones(self.max_num_reqs *
                                   self.decode_token_per_req,
@@ -1843,6 +1844,7 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin):
         # in the same group share the same metadata.
         for kv_cache_group_id, kv_cache_group_spec in enumerate(
                 self.kv_cache_config.kv_cache_groups):
+            # NOTE: This is strange, why did we use total_num_scheduled_tokens before?
             slot_mapping_size = (total_num_scheduled_tokens
                                  if self.pcp_size == 1 else
                                  total_num_scheduled_tokens * self.pcp_size -
```
```diff
@@ -1864,7 +1866,6 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin):
             else:
                 blk_table = self.input_batch.block_table[kv_cache_group_id]
                 blk_table_tensor = blk_table.get_device_tensor()
-                slot_mapping = blk_table.slot_mapping[:slot_mapping_size]
                 blk_table.slot_mapping[slot_mapping_size:].fill_(0)
                 if self.pcp_size > 1:
                     slot_mapping_for_pcp = blk_table.slot_mapping[:
```
```diff
@@ -1884,14 +1885,48 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin):
                                                                   slot_mapping_size]
                     slot_mapping_for_pcp[:long_seq_metadata.
                                          num_actual_tokens_pcp_padded] = pcp_padded_slot_mapping
-                    slot_mapping = slot_mapping_for_pcp
+                    blk_table.slot_mapping[:long_seq_metadata.num_actual_tokens_pcp_padded] = \
+                        slot_mapping_for_pcp
+                slot_mapping = blk_table.slot_mapping
+
+            # NOTE: This is a temporary hack, now in GPUModelRunner, this prepare_inputs
+            # has been split to multiple parts, and there are 3 parts that is related to this
+            # `num_reqs`, we'll take `query_start_loc` as an example:
+            # 1. self.query_start_loc.np[1 : num_reqs + 1] = cu_num_tokens
+            # 2. get `num_reqs_padded`, this depends on dispatcher and which is why we have the
+            #    following simplified `dispatch` logic here, we try to minimize the impact
+            # 3. query_start_loc = self.query_start_loc.gpu[: num_reqs_padded + 1]
+            uniform_decode = (max_num_scheduled_tokens == self.uniform_decode_query_len) \
+                and (total_num_scheduled_tokens == max_num_scheduled_tokens * num_reqs)
+
+            # TODO: We should make this official ASAP. Also note that if we pad here,
+            # the builders won't need to add any extra padding.
+            if self.compilation_config.cudagraph_mode.decode_mode() == CUDAGraphMode.FULL and \
+                    uniform_decode:
+                num_reqs_padded = num_input_tokens // self.uniform_decode_query_len
+                pad_size = num_reqs_padded - num_reqs
+                if pad_size > 0:
+                    last_query_loc = self.query_start_loc[num_reqs]
+
+                    steps = torch.arange(1,
+                                         pad_size + 1,
+                                         device=self.device,
+                                         dtype=self.query_start_loc.dtype)
+                    fill_values = last_query_loc + (
+                        steps * self.uniform_decode_query_len)
+
+                    self.query_start_loc[num_reqs + 1:num_reqs_padded +
+                                         1] = fill_values
+                # So we are trying to simulate the behavior of GPUModelRunner's
+                # prepare_inputs for uniform decode mode by padding query_start_loc
+                num_reqs = num_reqs_padded
 
             # Make AscendCommonAttentionMetadata
             common_attn_metadata = AscendCommonAttentionMetadata(
                 query_start_loc=self.query_start_loc[:num_reqs + 1],
                 query_start_loc_cpu=self.query_start_loc_cpu[:num_reqs + 1],
                 seq_lens_cpu=self.seq_lens_cpu[:num_reqs],
-                seq_lens=self.seq_lens_cpu[:num_reqs],
+                seq_lens=self.seq_lens[:num_reqs],
                 num_reqs=num_reqs,
                 num_actual_tokens=slot_mapping_size,
                 num_input_tokens=num_input_tokens,
```
```diff
@@ -2876,6 +2911,8 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin):
             seq_lens = max_query_len
         self.seq_lens_np[:num_reqs] = seq_lens
         self.seq_lens_np[num_reqs:] = 0
+        self.seq_lens[:num_reqs].copy_(self.seq_lens_cpu[:num_reqs],
+                                       non_blocking=True)
 
         cu_num_tokens, arange = self._get_cumsum_and_arange(
             num_scheduled_tokens)
@@ -2906,21 +2943,22 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin):
                 [0] * dcp_world_size for _ in range(pcp_world_size)
             ] for _ in range(num_tokens)]
             long_seq_metadata.num_computed_tokens_of_pcp_dcp = num_computed_tokens_of_pcp_dcp
+        # QUESTION: Why do we separately set query_start_loc for spec in the first place?
+        # While in _prepare_inputs we don't?
         if self.speculative_config:
-            query_start_loc = torch.tensor(
+            self.query_start_loc[:num_reqs + 1] = torch.tensor(
                 [0] + self.actual_seq_lengths_q[:num_reqs],
                 device=self.device,
                 dtype=torch.int32)
-        else:
-            query_start_loc = self.query_start_loc[:num_reqs + 1]
         common_attn_metadata = AscendCommonAttentionMetadata(
-            query_start_loc=query_start_loc,
+            query_start_loc=self.query_start_loc[:num_reqs + 1],
             query_start_loc_cpu=self.query_start_loc_cpu[:num_reqs +
                                                          1],
             seq_lens_cpu=self.seq_lens_cpu,
-            seq_lens=self.seq_lens_cpu[:num_reqs],
+            seq_lens=self.seq_lens[:num_reqs],
             num_reqs=num_reqs,
             num_actual_tokens=num_tokens,
             num_input_tokens=num_tokens,
             actual_seq_lengths_q=self.actual_seq_lengths_q,
             block_table_tensor=block_table_tensor[:num_reqs],
             slot_mapping=slot_mapping,
```
```diff
@@ -3210,7 +3248,8 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin):
                     num_tokens_across_dp=num_tokens_across_dp,
                     aclgraph_runtime_mode=aclgraph_runtime_mode,
                     batch_descriptor=batch_descriptor,
-                    dummy_compute_logits=dummy_drafter_compute_logits)
+                    dummy_compute_logits=dummy_drafter_compute_logits,
+                    skip_attn=not force_attention)
             if self.in_profile_run and self.dynamic_eplb:
                 self.model.clear_all_moe_loads()
             if not self.in_profile_run and self.dynamic_eplb:
@@ -3373,7 +3412,6 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin):
         # wrap the model with full graph wrapper if needed.
         if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
             self.update_stream: torch.npu.Stream = torch.npu.Stream()
-            set_graph_params(self.compilation_config.cudagraph_capture_sizes)
             self.model = ACLGraphWrapper(self.model,
                                          self.vllm_config,
                                          runtime_mode=CUDAGraphMode.FULL)
@@ -4092,6 +4130,12 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin):
         self.aclgraph_batch_sizes = (capture_sizes
                                      if capture_sizes is not None else [])
 
+        # NOTE: Since aclgraph_batch_sizes cannot be determined until here,
+        # we set the graph params right before initializing the keys.
+        set_graph_params(self.aclgraph_batch_sizes)
+        if self.speculative_config:
+            set_mtp_graph_params(self.aclgraph_batch_sizes)
+
         self.aclgraph_dispatcher.initialize_cudagraph_keys(
             self.compilation_config.cudagraph_mode,
             self.uniform_decode_query_len)
```