[Refactor] remove some metadata variables in attention_v1. (#5160)
RFC: https://github.com/vllm-project/vllm-ascend/issues/4629
Reason:
The metadata dataclass contains an excessive number of variables. We
inherit the metadata from the vLLM community (upstream) and, at the same
time, remove variables that are no longer needed.
Todo:
1. Partially remove attn_state.
- vLLM version: v0.12.0
- vLLM main: ad32e3e19c
---------
Signed-off-by: weijinqian_v1 <weijinqian@huawei.com>
Co-authored-by: weijinqian_v1 <weijinqian@huawei.com>
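
For context, the net effect on the metadata class looks roughly like this
(a condensed, hypothetical sketch; only the field names, defaults, and the
removed/added annotations come from the hunks below, and the real class has
more members):

    from dataclasses import dataclass
    from typing import List, Optional

    import torch


    @dataclass
    class AscendMetadata:
        seq_lens: torch.Tensor = None
        # Kept: host-side per-request lengths handed to the fused kernels.
        actual_seq_lengths_q: List[int] = None  # type: ignore
        query_start_loc: torch.Tensor = None
        query_lens: torch.Tensor = None
        max_query_len: Optional[int] = None
        # Removed by this diff: seq_lens_list and query_start_loc_list
        # (Python-list mirrors of the tensors above) and the pooling flag
        # is_causal_pooling (Optional[bool]).
        # Added: upstream-style fields that replace is_causal_pooling.
        causal: bool = True
        model_runner_type: str = ""  # runner_type in model_config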
@@ -34,7 +34,7 @@ from vllm.v1.core.sched.output import SchedulerOutput
 from vllm.v1.kv_cache_interface import AttentionSpec

 from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
-                                         split_decodes_and_prefills,
+                                         enable_cp, split_decodes_and_prefills,
                                          using_paged_attention)
 from vllm_ascend.compilation.acl_graph import (get_graph_params,
                                                update_graph_params_workspaces)
@@ -52,9 +52,7 @@ class AscendAttentionBackend(AttentionBackend):

     @staticmethod
     def get_impl_cls() -> Type["AscendAttentionBackendImpl"]:
-        prefill_config = get_current_vllm_config().parallel_config
-        if (prefill_config.prefill_context_parallel_size > 1
-                or prefill_config.decode_context_parallel_size > 1):
+        if enable_cp():
             from vllm_ascend.attention.attention_cp import \
                 AscendAttentionCPImpl
             return AscendAttentionCPImpl
@@ -62,9 +60,7 @@ class AscendAttentionBackend(AttentionBackend):

     @staticmethod
     def get_builder_cls() -> type["AscendAttentionMetadataBuilder"]:
-        prefill_config = get_current_vllm_config().parallel_config
-        if (prefill_config.prefill_context_parallel_size > 1
-                or prefill_config.decode_context_parallel_size > 1):
+        if enable_cp():
             from vllm_ascend.attention.attention_cp import \
                 AscendAttentionCPMetadataBuilder
             return AscendAttentionCPMetadataBuilder
@@ -191,10 +187,8 @@ class AscendMetadata:
     seq_lens: torch.Tensor = None
-    seq_lens_list: List[int] = None  # type: ignore
     actual_seq_lengths_q: List[int] = None  # type: ignore
-    query_start_loc_list: List[int] = None  # type: ignore

     query_start_loc: torch.Tensor = None
     query_lens: torch.Tensor = None
     # Maximum query length in the batch (None for decoding).
     max_query_len: Optional[int] = None

@@ -214,9 +208,9 @@ class AscendMetadata:
     # dcp
     decode_meta: Optional[AscendMetadataForDecode] = None

-    # Whether is the pooling model with causal attention,
-    # used to guide the attention computation for pooling models.
-    is_causal_pooling: Optional[bool] = None
+    causal: bool = True
+    # runner_type in model_config.
+    model_runner_type: str = ""


 class AscendAttentionMetadataBuilder:
@@ -276,11 +270,8 @@ class AscendAttentionMetadataBuilder:

         num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \
             split_decodes_and_prefills(common_attn_metadata, decode_threshold=self.decode_threshold)
-        assert num_decodes + num_prefills == num_reqs
-        assert num_decode_tokens + num_prefill_tokens == num_actual_tokens
-
         block_table = common_attn_metadata.block_table_tensor
         query_lens = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
         seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs]

         long_seq_metadata = common_attn_metadata.prefill_context_parallel_metadata
@@ -297,10 +288,6 @@ class AscendAttentionMetadataBuilder:
         # TODO: Yet another unnecessary H2D while we already have a query_start_loc on device
         query_start_loc = query_start_loc_cpu.pin_memory().to(
             self.device, non_blocking=True)
-        is_causal_pooling = None
-        if self.model_config.runner_type == "pooling":
-            is_causal_pooling = common_attn_metadata.causal if hasattr(
-                common_attn_metadata, 'causal') else True

         attn_metadata = AscendMetadata(
             num_actual_tokens=num_actual_tokens,
@@ -308,8 +295,6 @@ class AscendAttentionMetadataBuilder:
             num_actual_tokens_pcp_padded=num_actual_tokens_pcp_padded,
             block_tables=block_table,
             query_start_loc=query_start_loc,
-            query_start_loc_list=query_start_loc_cpu[1:].tolist(),
             query_lens=query_lens,
             seq_lens=seq_lens,
-            seq_lens_list=seq_lens.tolist(),
             max_query_len=common_attn_metadata.max_query_len,
@@ -319,7 +304,8 @@ class AscendAttentionMetadataBuilder:
             attn_state=attn_state,
             num_prefills=num_prefills,
             num_decodes=num_decodes,
-            is_causal_pooling=is_causal_pooling)
+            causal=common_attn_metadata.causal,
+            model_runner_type=self.model_config.runner_type)
         return attn_metadata

     def build_for_graph_capture(
@@ -384,9 +370,9 @@ class AscendAttentionBackendImpl(AttentionImpl):
         key, value, block_size, block_table, actual_seq_lengths_kv \
             = self._get_fia_params(key, value, attn_metadata)

-        num_tokens = attn_metadata.query_start_loc_list[-1]
+        num_tokens = attn_metadata.actual_seq_lengths_q[-1]
         graph_params = get_graph_params()
-        query_start_loc = attn_metadata.query_start_loc_list
+        actual_seq_lengths_q = attn_metadata.actual_seq_lengths_q
         # Prepare tensors for attention output
         # TODO: Refactor this to step-level instead of layer-level

@@ -402,7 +388,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
             block_table=block_table,
             input_layout="TND",
             block_size=block_size,
-            actual_seq_lengths=query_start_loc,
+            actual_seq_lengths=actual_seq_lengths_q,
             actual_seq_lengths_kv=actual_seq_lengths_kv,
             num_key_value_heads=self.num_kv_heads,
             num_heads=self.num_heads,
@@ -422,7 +408,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
             (weak_ref_tensors(query), weak_ref_tensors(key),
              weak_ref_tensors(value), weak_ref_tensors(block_table),
              weak_ref_tensors(attn_metadata.attn_mask), block_size,
-             actual_seq_lengths_kv, query_start_loc, self.num_kv_heads,
+             actual_seq_lengths_kv, actual_seq_lengths_q, self.num_kv_heads,
              self.num_heads, self.scale, weak_ref_tensors(output),
              weak_ref_tensors(softmax_lse)))

@@ -435,7 +421,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
             block_table=block_table,
             input_layout="TND",
             block_size=block_size,
-            actual_seq_lengths=query_start_loc,
+            actual_seq_lengths=actual_seq_lengths_q,
             actual_seq_lengths_kv=actual_seq_lengths_kv,
             num_key_value_heads=self.num_kv_heads,
             num_heads=self.num_heads,
@@ -518,10 +504,10 @@ class AscendAttentionBackendImpl(AttentionImpl):
         if attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
             block_size = 128
             block_table = None
-            actual_seq_lengths_kv = attn_metadata.query_start_loc_list
+            actual_seq_lengths_kv = attn_metadata.actual_seq_lengths_q
         elif attn_metadata.attn_state == \
                 AscendAttentionState.PrefillCacheHit:
-            batch_size = attn_metadata.query_lens.shape[0]
+            batch_size = attn_metadata.seq_lens.shape[0]
             block_table = attn_metadata.block_tables[:batch_size, :]
             num_block, block_size, _, _ = self.key_cache.shape  # type: ignore
             key = self.key_cache.view(  # type: ignore
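
A note on the query_start_loc_list -> actual_seq_lengths_q swaps above: the
removed field was built as query_start_loc_cpu[1:].tolist(), i.e. the
cumulative end offset of each request's query tokens, which is why its last
element could double as num_tokens. Assuming actual_seq_lengths_q carries the
same cumulative offsets (which the one-for-one kwarg replacement suggests),
the equivalence in a toy, self-contained form:

    import torch

    query_lens = torch.tensor([3, 1, 4])  # query tokens per request (made-up values)
    query_start_loc = torch.cumsum(
        torch.cat([torch.zeros(1, dtype=torch.long), query_lens]), dim=0)
    # query_start_loc == [0, 3, 4, 8]; the removed field shipped its tail:
    actual_seq_lengths_q = query_start_loc[1:].tolist()  # [3, 4, 8]

    assert actual_seq_lengths_q[-1] == int(query_lens.sum())  # num_tokens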
@@ -644,9 +630,8 @@ class AscendAttentionBackendImpl(AttentionImpl):
                                    attn_metadata: AscendMetadata,
                                    _: torch.Tensor) -> torch.Tensor:
         assert attn_metadata is not None
-        assert attn_metadata.is_causal_pooling is not None

-        if attn_metadata.is_causal_pooling:
+        if attn_metadata.causal:
             # use sparse_mode 3 in causal scenario
             return torch_npu.npu_fusion_attention(
                 query=query,
@@ -768,7 +753,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
             key, value = self.reshape_and_cache(key, value, kv_cache,
                                                 attn_metadata)
         # pooling model branch
-        if isinstance(attn_metadata.is_causal_pooling, bool):
+        if attn_metadata.model_runner_type == "pooling":
             attn_output = self._forward_encoder_attention(
                 query, key, value, attn_metadata, output)
             output[:num_tokens] = attn_output[:num_tokens]
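
The is_causal_pooling removal follows the same idea: a single Optional[bool]
had to encode both "is this a pooling model?" (hence the isinstance(..., bool)
check above) and "is its attention causal?". The two new fields make that
split explicit. A toy, self-contained sketch of the resulting dispatch (only
the two field names come from this diff; everything else is illustrative):

    from dataclasses import dataclass


    @dataclass
    class Meta:
        causal: bool = True          # was: the truth value of is_causal_pooling
        model_runner_type: str = ""  # was: is_causal_pooling is not None


    def forward(meta: Meta) -> str:
        if meta.model_runner_type == "pooling":
            # encoder path; causal selects the masked (sparse) kernel mode
            return "causal encoder" if meta.causal else "bidirectional encoder"
        return "decoder"


    assert forward(Meta(model_runner_type="pooling", causal=False)) == "bidirectional encoder"
    assert forward(Meta()) == "decoder"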