support qwen3-next full_decode_only mode. (#3949)
### What this PR does / why we need it?
support qwen3-next full_decode_only mode.
bs=1, max_token=1024
| branch| tps| e2e time|
| --- | --- | --- |
|piecewise |3.06 | 8.15 |
|fulldecodeonly | 7.2 | 3.47 |
- vLLM version: v0.11.0
- vLLM main:
83f478bb19
Signed-off-by: wangxiaoxin-sherie <wangxiaoxin7@huawei.com>
Co-authored-by: wangxiaoxin-sherie <wangxiaoxin7@huawei.com>
This commit is contained in:
@@ -76,7 +76,8 @@ from vllm.utils import cdiv, length_from_prompt_token_ids_or_embeds
|
||||
from vllm.utils.jsontree import json_map_leaves
|
||||
from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadataBuilder
|
||||
from vllm.v1.attention.backends.utils import (
|
||||
AttentionCGSupport, reorder_batch_to_split_decodes_and_prefills)
|
||||
AttentionCGSupport, CommonAttentionMetadata,
|
||||
reorder_batch_to_split_decodes_and_prefills)
|
||||
from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher
|
||||
# yapf conflicts with isort for this block
|
||||
# yapf: disable
|
||||
@@ -107,7 +108,8 @@ from vllm_ascend.ascend_config import get_ascend_config
|
||||
from vllm_ascend.ascend_forward_context import (MoECommType,
|
||||
set_ascend_forward_context)
|
||||
from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
|
||||
from vllm_ascend.attention.attention_v1 import AscendAttentionState
|
||||
from vllm_ascend.attention.attention_v1 import (AscendAttentionMetadataBuilder,
|
||||
AscendAttentionState)
|
||||
from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
|
||||
AscendPrefillContextParallelMetadata)
|
||||
# yapf: disable
|
||||
@@ -2644,6 +2646,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
num_reqs: int,
|
||||
num_tokens: int,
|
||||
max_query_len: int,
|
||||
num_scheduled_tokens: np.ndarray,
|
||||
aclgraph_runtime_mode: Optional[CUDAGraphMode] = None,
|
||||
force_attention: bool = False,
|
||||
) -> Optional[dict[str, Any]]:
|
||||
@@ -2659,6 +2662,14 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
self.seq_lens_np[:num_reqs] = seq_lens
|
||||
self.seq_lens_np[num_reqs:] = 0
|
||||
|
||||
cu_num_tokens, arange = self._get_cumsum_and_arange(
|
||||
num_scheduled_tokens)
|
||||
query_start_loc_tensor = torch.Tensor(cu_num_tokens).to(
|
||||
self.device).to(torch.int32)
|
||||
self.query_start_loc[1:num_reqs + 1] = query_start_loc_tensor
|
||||
self.query_start_loc_cpu[1:num_reqs +
|
||||
1] = torch.Tensor(cu_num_tokens)
|
||||
|
||||
num_computed_tokens_cpu = (
|
||||
self.input_batch.num_computed_tokens_cpu_tensor[:num_reqs])
|
||||
|
||||
@@ -2715,12 +2726,35 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
self.speculative_config.method == "deepseek_mtp":
|
||||
attn_state = AscendAttentionState.SpecDecoding
|
||||
|
||||
common_metadata = CommonAttentionMetadata(
|
||||
query_start_loc=self.query_start_loc[:num_reqs + 1],
|
||||
query_start_loc_cpu=self.query_start_loc_cpu[:num_reqs +
|
||||
1],
|
||||
seq_lens_cpu=self.seq_lens_cpu[:num_reqs],
|
||||
seq_lens=self.seq_lens_cpu[:num_reqs],
|
||||
num_reqs=num_reqs,
|
||||
num_actual_tokens=num_tokens,
|
||||
block_table_tensor=block_table_tensor[:num_reqs],
|
||||
slot_mapping=slot_mapping,
|
||||
num_computed_tokens_cpu=num_computed_tokens_cpu,
|
||||
max_query_len=max_query_len,
|
||||
max_seq_len=seq_lens)
|
||||
|
||||
for attn_group in self.attn_groups[kv_cache_group_id]:
|
||||
builder = attn_group.get_metadata_builder()
|
||||
attn_metadata_i = builder.build_for_graph_capture(
|
||||
common_attn_metadata, attn_state, self.get_model())
|
||||
if isinstance(builder, AscendAttentionMetadataBuilder):
|
||||
attn_metadata_full_attention = builder.build_for_graph_capture(
|
||||
common_attn_metadata, attn_state, self.get_model())
|
||||
elif isinstance(builder, GDNAttentionMetadataBuilder):
|
||||
attn_metadata_gdn_attention = builder.build_for_cudagraph_capture(
|
||||
common_metadata)
|
||||
for layer_name in kv_cache_group_spec.layer_names:
|
||||
attn_metadata[layer_name] = attn_metadata_i
|
||||
if "linear_attn" in layer_name:
|
||||
attn_metadata[
|
||||
layer_name] = attn_metadata_gdn_attention
|
||||
else:
|
||||
attn_metadata[
|
||||
layer_name] = attn_metadata_full_attention
|
||||
|
||||
return attn_metadata
|
||||
|
||||
@@ -2895,6 +2929,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
max_query_len=max_query_len,
|
||||
aclgraph_runtime_mode=aclgraph_runtime_mode,
|
||||
force_attention=force_attention,
|
||||
num_scheduled_tokens=num_scheduled_tokens,
|
||||
)
|
||||
|
||||
need_dummy_logits = (not self.in_profile_run
|
||||
|
||||
Reference in New Issue
Block a user