[Misc] Clean up uesless code in attention (#1933)
Before do attention module refactor, we can do some code cleanup to make
the next step easier.
What this PR does:
1. remove uesless `common_prefix_len` for attention builder
2. remove uesless `is_only_prefill` and `num_input_tokens` in attention
metadata.
3. remove `CommonAttentionMetadata` and ues `query_start_loc` instead,
`CommonAttentionMetadata` is over designed and uesless
4. update the attention backend input parameters to keep the same as
vLLM.
5. Rename attention name to the same style with `ASCEND` prefix
- vLLM version: v0.9.2
- vLLM main:
107111a859
Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
This commit is contained in:
@@ -75,8 +75,7 @@ from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
|
||||
from vllm_ascend.attention.attention_v1 import (AscendAttentionState,
|
||||
AscendMetadata)
|
||||
from vllm_ascend.attention.attention_v1_torchair import AscendTorchairMetadata
|
||||
from vllm_ascend.attention.mla_v1 import (AscendMLAMetadata,
|
||||
CommonAttentionMetadata)
|
||||
from vllm_ascend.attention.mla_v1 import AscendMLAMetadata
|
||||
from vllm_ascend.platform import NPUPlatform
|
||||
from vllm_ascend.sample.rejection_sampler import AscendRejectionSampler
|
||||
from vllm_ascend.torchair.utils import (check_torchair_cache_exist,
|
||||
@@ -694,15 +693,10 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
# in the same group share the same metadata.
|
||||
for kv_cache_group_id, kv_cache_group_spec in enumerate(
|
||||
self.kv_cache_config.kv_cache_groups):
|
||||
|
||||
# Prepare for cascade attention if enabled & beneficial.
|
||||
common_prefix_len = 0
|
||||
|
||||
attn_metadata_i = self.attn_metadata_builder.build(
|
||||
num_reqs=num_reqs,
|
||||
num_actual_tokens=total_num_scheduled_tokens,
|
||||
max_query_len=max_num_scheduled_tokens,
|
||||
common_prefix_len=common_prefix_len,
|
||||
)
|
||||
for layer_name in kv_cache_group_spec.layer_names:
|
||||
attn_metadata[layer_name] = attn_metadata_i
|
||||
@@ -1049,27 +1043,22 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
extra_builder_kwargs['graph_pad_size'] = graph_pad_size
|
||||
|
||||
if self.vllm_config.model_config.use_mla:
|
||||
query_start_loc = self.query_start_loc[:num_reqs + 1]
|
||||
seq_lens = self.seq_lens[:num_reqs]
|
||||
common_attn_metadata = CommonAttentionMetadata(
|
||||
query_start_loc=query_start_loc, seq_lens=seq_lens)
|
||||
extra_builder_kwargs[
|
||||
"query_start_loc"] = self.query_start_loc[:num_reqs + 1]
|
||||
attn_metadata = self.attn_metadata_builder.build( # type: ignore
|
||||
num_reqs=num_reqs,
|
||||
num_actual_tokens=total_num_scheduled_tokens,
|
||||
max_query_len=max_num_scheduled_tokens,
|
||||
common_attn_metadata=common_attn_metadata,
|
||||
common_prefix_len=None,
|
||||
**extra_builder_kwargs,
|
||||
)
|
||||
attn_metadata.num_input_tokens = num_input_tokens
|
||||
else:
|
||||
attn_metadata = self.attn_metadata_builder.build( # type: ignore
|
||||
num_reqs=num_reqs,
|
||||
num_actual_tokens=total_num_scheduled_tokens,
|
||||
max_query_len=max_num_scheduled_tokens,
|
||||
common_prefix_len=None,
|
||||
**extra_builder_kwargs,
|
||||
)
|
||||
attn_metadata.num_input_tokens = num_input_tokens
|
||||
|
||||
# Prepare input_ids
|
||||
token_indices = (positions_np +
|
||||
|
||||
Reference in New Issue
Block a user