[Refactor] remove some metadata variables in attention_v1. (#5160)
RFC: https://github.com/vllm-project/vllm-ascend/issues/4629
Reason:
The metadata dataclass contains an excessive number of variables. We
inherit the community (vLLM) metadata and, at the same time, remove
some variables that are no longer needed at present (see the sketch
after the todo list).
Todo:
1. Partially remove attn_state.
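Below is a minimal, hypothetical sketch of the direction described above: keep a community-style common metadata base and layer only the Ascend-specific fields on top, so redundant variables can be dropped. The class and field names are illustrative assumptions, not the actual classes touched by this commit.

from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class CommonAttentionMetadataSketch:
    """Fields shared with the community (vLLM-style) common metadata."""
    query_start_loc: torch.Tensor
    seq_lens: torch.Tensor
    num_reqs: int
    num_actual_tokens: int
    causal: bool = True


@dataclass
class AscendAttentionMetadataSketch(CommonAttentionMetadataSketch):
    """Only Ascend-specific extras live here; flags that can be derived
    from other fields are dropped instead of being stored."""
    slot_mapping: Optional[torch.Tensor] = None
    graph_pad_size: int = -1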
- vLLM version: v0.12.0
- vLLM main: ad32e3e19c
---------
Signed-off-by: weijinqian_v1 <weijinqian@huawei.com>
Co-authored-by: weijinqian_v1 <weijinqian@huawei.com>
@@ -1,9 +1,10 @@
 from dataclasses import dataclass
+from functools import lru_cache
 from typing import Any, List, Optional

 import torch
 import torch.nn.functional as F
-from vllm.config import VllmConfig
+from vllm.config import VllmConfig, get_current_vllm_config
 from vllm.distributed.kv_transfer import (get_kv_transfer_group,
                                           has_kv_transfer_group,
                                           is_v1_kv_transfer_group)
@@ -26,6 +27,13 @@ def using_paged_attention(runtime_shape: int, vllm_config: VllmConfig) -> bool:
     return runtime_shape in get_ascend_config().pa_shape_list


+@lru_cache(maxsize=1)
+def enable_cp():
+    prefill_config = get_current_vllm_config().parallel_config
+    return prefill_config.prefill_context_parallel_size > 1 \
+        or prefill_config.decode_context_parallel_size > 1
+
+
 @dataclass
 # class AscendCommonLongSequenceMetadata:
 class AscendPrefillContextParallelMetadata:
@@ -66,7 +74,7 @@ class AscendCommonAttentionMetadata:
     Per-batch attention metadata, shared across layers and backends.
     AttentionMetadataBuilder instances use it to construct per-layer metadata.

-    For many of the tensors we keep both GPU and CPU versions.
+    For many of the tensors we keep both NPU and CPU versions.
     """

     query_start_loc: torch.Tensor
@@ -109,8 +117,6 @@ class AscendCommonAttentionMetadata:

     attn_state: Any = None

     is_only_prefill: bool = False

     graph_pad_size: int = -1

     # num_input_tokens refers to total number of tokens including
@@ -120,6 +126,8 @@ class AscendCommonAttentionMetadata:
     prefill_context_parallel_metadata: Optional[
         AscendPrefillContextParallelMetadata] = None

     causal: bool = True

     # TODO: Remove it when vLLM no longer uses this function.
     def unpadded(self, num_actual_tokens: int,
                  num_actual_reqs: int) -> "AscendCommonAttentionMetadata":
@@ -137,12 +145,12 @@ class AscendCommonAttentionMetadata:
            decode_token_per_req=self.decode_token_per_req,
            block_table_tensor=self.block_table_tensor[:num_actual_reqs],
            slot_mapping=self.slot_mapping[:num_actual_tokens],
            causal=self.causal,
            actual_seq_lengths_q=self.actual_seq_lengths_q[:num_actual_tokens],
            positions=self.positions[:num_actual_tokens],
            attn_mask=self.attn_mask,
            spec_attn_mask=self.spec_attn_mask,
            attn_state=self.attn_state,
            is_only_prefill=self.is_only_prefill,
            graph_pad_size=-1,  # It should be -1 when not run in fullgraph mode.
            num_input_tokens=num_actual_tokens,
            prefill_context_parallel_metadata=self.