[Misc] Clean up uesless code in attention (#1933)
Before do attention module refactor, we can do some code cleanup to make
the next step easier.
What this PR does:
1. remove uesless `common_prefix_len` for attention builder
2. remove uesless `is_only_prefill` and `num_input_tokens` in attention
metadata.
3. remove `CommonAttentionMetadata` and ues `query_start_loc` instead,
`CommonAttentionMetadata` is over designed and uesless
4. update the attention backend input parameters to keep the same as
vLLM.
5. Rename attention name to the same style with `ASCEND` prefix
- vLLM version: v0.9.2
- vLLM main:
107111a859
Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
This commit is contained in:
@@ -31,27 +31,13 @@ if TYPE_CHECKING:
|
||||
_ALLOWED_NUM_QUERIES_PER_KV = [32, 64, 128]
|
||||
|
||||
|
||||
@dataclass
|
||||
class CommonAttentionMetadata:
|
||||
"""
|
||||
Attention metadata attributes that can be shared by layers in different KV
|
||||
cache groups and thus having different block table.
|
||||
"""
|
||||
|
||||
query_start_loc: torch.Tensor
|
||||
"""(batch_size + 1,), the start location of each request in query Tensor"""
|
||||
seq_lens: torch.Tensor
|
||||
"""(batch_size,), the length of each request including both computed tokens
|
||||
and newly scheduled tokens"""
|
||||
|
||||
|
||||
class AscendMLABackend(AttentionBackend):
|
||||
|
||||
accept_output_buffer: bool = True
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "VLLM_ASCEND_MLA"
|
||||
return "ASCEND_MLA"
|
||||
|
||||
@staticmethod
|
||||
def get_metadata_cls() -> type["AttentionMetadata"]:
|
||||
@@ -368,11 +354,10 @@ class AscendMLAMetadataBuilder:
|
||||
num_reqs: int,
|
||||
num_actual_tokens: int,
|
||||
max_query_len: int,
|
||||
common_attn_metadata: CommonAttentionMetadata,
|
||||
common_prefix_len: Optional[int] = None,
|
||||
graph_pad_size: int = -1,
|
||||
max_num_tokens_across_dp: int = 0,
|
||||
with_prefill_across_dp: bool = False,
|
||||
query_start_loc: torch.Tensor = None,
|
||||
) -> AscendMLAMetadata:
|
||||
assert self._num_decodes + self._num_prefills == num_reqs
|
||||
|
||||
@@ -394,7 +379,6 @@ class AscendMLAMetadataBuilder:
|
||||
seq_lens = seq_lens_cpu
|
||||
max_query_len = query_lens.max().item()
|
||||
max_seq_lens = seq_lens.max().item()
|
||||
query_start_loc = common_attn_metadata.query_start_loc
|
||||
|
||||
prefill_metadata = None
|
||||
chunked_context_metadata = None
|
||||
@@ -403,7 +387,6 @@ class AscendMLAMetadataBuilder:
|
||||
tokens_start = self._num_decode_tokens
|
||||
max_query_len = query_lens[tokens_start:].max().item()
|
||||
max_seq_lens = seq_lens[tokens_start:].max().item()
|
||||
query_start_loc = common_attn_metadata.query_start_loc
|
||||
prefill_query_start_loc = query_start_loc[
|
||||
reqs_start:] - query_start_loc[reqs_start]
|
||||
|
||||
@@ -539,7 +522,7 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
kv_cache_dtype: str,
|
||||
logits_soft_cap: Optional[float],
|
||||
attn_type: str,
|
||||
kv_sharing_target_layer_name: Optional[str] = None,
|
||||
kv_sharing_target_layer_name: Optional[str],
|
||||
**kwargs,
|
||||
) -> None:
|
||||
self.num_heads = num_heads
|
||||
|
||||
Reference in New Issue
Block a user