[Refactor] remove some metadata variables in attention_v1. (#5160)
RFC: https://github.com/vllm-project/vllm-ascend/issues/4629
Reason:
The AscendMetadata dataclass contains an excessive number of variables.
We now inherit the community (vLLM) common attention metadata and remove
the variables that are no longer needed.
Todo:
1. Partially remove attn_state.
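Note: the removed fields were redundant views of query_start_loc and can be
derived on demand. A minimal sketch (illustrative values, not from this PR):

    import torch

    # query_start_loc holds cumulative scheduled-token offsets per request,
    # e.g. three requests with 2, 5 and 1 tokens.
    query_start_loc_cpu = torch.tensor([0, 2, 7, 8])

    # Per-request lengths, previously stored as AscendMetadata.query_lens.
    query_lens = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]  # tensor([2, 5, 1])

    # Cumulative form as a Python list, previously query_start_loc_list;
    # after this refactor actual_seq_lengths_q carries the same values.
    actual_seq_lengths_q = query_start_loc_cpu[1:].tolist()  # [2, 7, 8]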
- vLLM version: v0.12.0
- vLLM main: ad32e3e19c
---------
Signed-off-by: weijinqian_v1 <weijinqian@huawei.com>
Co-authored-by: weijinqian_v1 <weijinqian@huawei.com>
@@ -235,8 +235,6 @@ class AscendAttentionCPMetadataBuilder(AscendAttentionMetadataBuilder):
             num_actual_tokens_pcp_padded=num_actual_tokens_pcp_padded,
             block_tables=block_table,
             query_start_loc=query_start_loc,
-            query_start_loc_list=query_start_loc_cpu[1:].tolist(),
-            query_lens=query_lens,
             seq_lens=seq_lens,
             seq_lens_list=seq_lens.tolist(),
             max_query_len=common_attn_metadata.max_query_len,

@@ -34,7 +34,7 @@ from vllm.v1.core.sched.output import SchedulerOutput
 from vllm.v1.kv_cache_interface import AttentionSpec
 
 from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
-                                         split_decodes_and_prefills,
+                                         enable_cp, split_decodes_and_prefills,
                                          using_paged_attention)
 from vllm_ascend.compilation.acl_graph import (get_graph_params,
                                                update_graph_params_workspaces)

@@ -52,9 +52,7 @@ class AscendAttentionBackend(AttentionBackend):
 
     @staticmethod
     def get_impl_cls() -> Type["AscendAttentionBackendImpl"]:
-        prefill_config = get_current_vllm_config().parallel_config
-        if (prefill_config.prefill_context_parallel_size > 1
-                or prefill_config.decode_context_parallel_size > 1):
+        if enable_cp():
             from vllm_ascend.attention.attention_cp import \
                 AscendAttentionCPImpl
             return AscendAttentionCPImpl

@@ -62,9 +60,7 @@ class AscendAttentionBackend(AttentionBackend):
 
     @staticmethod
     def get_builder_cls() -> type["AscendAttentionMetadataBuilder"]:
-        prefill_config = get_current_vllm_config().parallel_config
-        if (prefill_config.prefill_context_parallel_size > 1
-                or prefill_config.decode_context_parallel_size > 1):
+        if enable_cp():
             from vllm_ascend.attention.attention_cp import \
                 AscendAttentionCPMetadataBuilder
             return AscendAttentionCPMetadataBuilder

@@ -191,10 +187,8 @@ class AscendMetadata:
     seq_lens: torch.Tensor = None
     seq_lens_list: List[int] = None  # type: ignore
     actual_seq_lengths_q: List[int] = None  # type: ignore
-    query_start_loc_list: List[int] = None  # type: ignore
 
     query_start_loc: torch.Tensor = None
-    query_lens: torch.Tensor = None
     # Maximum query length in the batch (None for decoding).
     max_query_len: Optional[int] = None
 

@@ -214,9 +208,9 @@ class AscendMetadata:
     # dcp
     decode_meta: Optional[AscendMetadataForDecode] = None
 
-    # Whether is the pooling model with causal attention,
-    # used to guide the attention computation for pooling models.
-    is_causal_pooling: Optional[bool] = None
+    causal: bool = True
+    # runner_type in model_config.
+    model_runner_type: str = ""
 
 
 class AscendAttentionMetadataBuilder:

@@ -276,11 +270,8 @@ class AscendAttentionMetadataBuilder:
 
         num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \
             split_decodes_and_prefills(common_attn_metadata, decode_threshold=self.decode_threshold)
-        assert num_decodes + num_prefills == num_reqs
-        assert num_decode_tokens + num_prefill_tokens == num_actual_tokens
 
         block_table = common_attn_metadata.block_table_tensor
-        query_lens = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
         seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs]
 
         long_seq_metadata = common_attn_metadata.prefill_context_parallel_metadata

@@ -297,10 +288,6 @@ class AscendAttentionMetadataBuilder:
         # TODO: Yet another unnecessary H2D while we already have a query_start_loc on device
         query_start_loc = query_start_loc_cpu.pin_memory().to(
             self.device, non_blocking=True)
-        is_causal_pooling = None
-        if self.model_config.runner_type == "pooling":
-            is_causal_pooling = common_attn_metadata.causal if hasattr(
-                common_attn_metadata, 'causal') else True
 
         attn_metadata = AscendMetadata(
             num_actual_tokens=num_actual_tokens,

@@ -308,8 +295,6 @@ class AscendAttentionMetadataBuilder:
             num_actual_tokens_pcp_padded=num_actual_tokens_pcp_padded,
             block_tables=block_table,
             query_start_loc=query_start_loc,
-            query_start_loc_list=query_start_loc_cpu[1:].tolist(),
-            query_lens=query_lens,
             seq_lens=seq_lens,
             seq_lens_list=seq_lens.tolist(),
             max_query_len=common_attn_metadata.max_query_len,

@@ -319,7 +304,8 @@ class AscendAttentionMetadataBuilder:
             attn_state=attn_state,
             num_prefills=num_prefills,
             num_decodes=num_decodes,
-            is_causal_pooling=is_causal_pooling)
+            causal=common_attn_metadata.causal,
+            model_runner_type=self.model_config.runner_type)
         return attn_metadata
 
     def build_for_graph_capture(

@@ -384,9 +370,9 @@ class AscendAttentionBackendImpl(AttentionImpl):
         key, value, block_size, block_table, actual_seq_lengths_kv \
             = self._get_fia_params(key, value, attn_metadata)
 
-        num_tokens = attn_metadata.query_start_loc_list[-1]
+        num_tokens = attn_metadata.actual_seq_lengths_q[-1]
         graph_params = get_graph_params()
-        query_start_loc = attn_metadata.query_start_loc_list
+        actual_seq_lengths_q = attn_metadata.actual_seq_lengths_q
         # Prepare tensors for attention output
         # TODO: Refactor this to step-level instead of layer-level
 

@@ -402,7 +388,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
             block_table=block_table,
             input_layout="TND",
             block_size=block_size,
-            actual_seq_lengths=query_start_loc,
+            actual_seq_lengths=actual_seq_lengths_q,
             actual_seq_lengths_kv=actual_seq_lengths_kv,
             num_key_value_heads=self.num_kv_heads,
             num_heads=self.num_heads,

@@ -422,7 +408,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
                 (weak_ref_tensors(query), weak_ref_tensors(key),
                  weak_ref_tensors(value), weak_ref_tensors(block_table),
                  weak_ref_tensors(attn_metadata.attn_mask), block_size,
-                 actual_seq_lengths_kv, query_start_loc, self.num_kv_heads,
+                 actual_seq_lengths_kv, actual_seq_lengths_q, self.num_kv_heads,
                  self.num_heads, self.scale, weak_ref_tensors(output),
                  weak_ref_tensors(softmax_lse)))
 

@@ -435,7 +421,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
                 block_table=block_table,
                 input_layout="TND",
                 block_size=block_size,
-                actual_seq_lengths=query_start_loc,
+                actual_seq_lengths=actual_seq_lengths_q,
                 actual_seq_lengths_kv=actual_seq_lengths_kv,
                 num_key_value_heads=self.num_kv_heads,
                 num_heads=self.num_heads,

@@ -518,10 +504,10 @@ class AscendAttentionBackendImpl(AttentionImpl):
         if attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
             block_size = 128
             block_table = None
-            actual_seq_lengths_kv = attn_metadata.query_start_loc_list
+            actual_seq_lengths_kv = attn_metadata.actual_seq_lengths_q
         elif attn_metadata.attn_state == \
                 AscendAttentionState.PrefillCacheHit:
-            batch_size = attn_metadata.query_lens.shape[0]
+            batch_size = attn_metadata.seq_lens.shape[0]
             block_table = attn_metadata.block_tables[:batch_size, :]
             num_block, block_size, _, _ = self.key_cache.shape  # type: ignore
             key = self.key_cache.view(  # type: ignore

@@ -644,9 +630,8 @@ class AscendAttentionBackendImpl(AttentionImpl):
                                    attn_metadata: AscendMetadata,
                                    _: torch.Tensor) -> torch.Tensor:
         assert attn_metadata is not None
-        assert attn_metadata.is_causal_pooling is not None
 
-        if attn_metadata.is_causal_pooling:
+        if attn_metadata.causal:
             # use sparse_mode 3 in causal scenario
             return torch_npu.npu_fusion_attention(
                 query=query,

@@ -768,7 +753,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
             key, value = self.reshape_and_cache(key, value, kv_cache,
                                                 attn_metadata)
         # pooling model branch
-        if isinstance(attn_metadata.is_causal_pooling, bool):
+        if attn_metadata.model_runner_type == "pooling":
             attn_output = self._forward_encoder_attention(
                 query, key, value, attn_metadata, output)
             output[:num_tokens] = attn_output[:num_tokens]

@@ -23,6 +23,7 @@ from vllm_ascend import envs
 from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.attention.attention_v1 import AscendAttentionState
 from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
+                                         enable_cp,
                                          maybe_save_kv_layer_to_connector,
                                          split_decodes_and_prefills,
                                          trans_rope_weight, transdata,

@@ -57,8 +58,7 @@ class AscendMLABackend(AttentionBackend):
 
     @staticmethod
     def get_builder_cls():
-        prefill_config = get_current_vllm_config().parallel_config
-        if prefill_config.prefill_context_parallel_size > 1 or prefill_config.decode_context_parallel_size > 1:
+        if enable_cp():
             from vllm_ascend.attention.mla_cp import AscendMlaCPMetadataBuilder
             return AscendMlaCPMetadataBuilder
         return AscendMLAMetadataBuilder

@@ -70,8 +70,7 @@ class AscendMLABackend(AttentionBackend):
 
     @staticmethod
     def get_impl_cls() -> Type["MLAAttentionImpl"]:
-        prefill_config = get_current_vllm_config().parallel_config
-        if prefill_config.prefill_context_parallel_size > 1 or prefill_config.decode_context_parallel_size > 1:
+        if enable_cp():
             from vllm_ascend.attention.mla_cp import AscendMlaCPImpl
             return AscendMlaCPImpl
         return AscendMLAImpl

@@ -1,9 +1,10 @@
 from dataclasses import dataclass
+from functools import lru_cache
 from typing import Any, List, Optional
 
 import torch
 import torch.nn.functional as F
-from vllm.config import VllmConfig
+from vllm.config import VllmConfig, get_current_vllm_config
 from vllm.distributed.kv_transfer import (get_kv_transfer_group,
                                           has_kv_transfer_group,
                                           is_v1_kv_transfer_group)

@@ -26,6 +27,13 @@ def using_paged_attention(runtime_shape: int, vllm_config: VllmConfig) -> bool:
     return runtime_shape in get_ascend_config().pa_shape_list
 
 
+@lru_cache(maxsize=1)
+def enable_cp():
+    prefill_config = get_current_vllm_config().parallel_config
+    return prefill_config.prefill_context_parallel_size > 1 \
+        or prefill_config.decode_context_parallel_size > 1
+
+
 @dataclass
 # class AscendCommonLongSequenceMetadata:
 class AscendPrefillContextParallelMetadata:

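Note: enable_cp() above replaces the prefill/decode context-parallel size
checks that were previously duplicated inline in the attention and MLA
backends. lru_cache(maxsize=1) makes the config lookup a one-time cost,
assuming the parallel config is fixed after startup. A minimal sketch of
that caching behavior (hypothetical stand-in, not the PR's code):

    from functools import lru_cache

    @lru_cache(maxsize=1)
    def enable_cp_sketch() -> bool:
        # Stand-in for reading get_current_vllm_config().parallel_config.
        pcp_size, dcp_size = 2, 1
        print("config read once")
        return pcp_size > 1 or dcp_size > 1

    enable_cp_sketch()  # prints "config read once", returns True
    enable_cp_sketch()  # cache hit: no print, returns True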
@@ -66,7 +74,7 @@ class AscendCommonAttentionMetadata:
     Per-batch attention metadata, shared across layers and backends.
     AttentionMetadataBuilder instances use it to construct per-layer metadata.
 
-    For many of the tensors we keep both GPU and CPU versions.
+    For many of the tensors we keep both NPU and CPU versions.
     """
 
     query_start_loc: torch.Tensor

@@ -109,8 +117,6 @@ class AscendCommonAttentionMetadata:
 
     attn_state: Any = None
 
-    is_only_prefill: bool = False
-
     graph_pad_size: int = -1
 
     # num_input_tokens refers to total number of tokens including

@@ -120,6 +126,8 @@ class AscendCommonAttentionMetadata:
     prefill_context_parallel_metadata: Optional[
         AscendPrefillContextParallelMetadata] = None
 
+    causal: bool = True
+
     # TODO: Remove it when vLLM no longer uses this function.
     def unpadded(self, num_actual_tokens: int,
                  num_actual_reqs: int) -> "AscendCommonAttentionMetadata":

@@ -137,12 +145,12 @@ class AscendCommonAttentionMetadata:
             decode_token_per_req=self.decode_token_per_req,
             block_table_tensor=self.block_table_tensor[:num_actual_reqs],
             slot_mapping=self.slot_mapping[:num_actual_tokens],
+            causal=self.causal,
             actual_seq_lengths_q=self.actual_seq_lengths_q[:num_actual_tokens],
             positions=self.positions[:num_actual_tokens],
             attn_mask=self.attn_mask,
             spec_attn_mask=self.spec_attn_mask,
             attn_state=self.attn_state,
-            is_only_prefill=self.is_only_prefill,
             graph_pad_size=-1,  # It should be -1 when not run in fullgraph mode.
             num_input_tokens=num_actual_tokens,
             prefill_context_parallel_metadata=self.

@@ -271,8 +271,8 @@ def _update_attn_fia_params(update_stream, forward_context, runtime_shape):
          attn_output, softmax_lse) = param
 
         seq_lens = forward_context.attn_metadata[key].seq_lens_list
-        query_start_loc = forward_context.attn_metadata[
-            key].query_start_loc_list
+        actual_seq_lengths_q = forward_context.attn_metadata[
+            key].actual_seq_lengths_q
         torch.npu.graph_task_update_begin(update_stream, handle)
         torch_npu.npu_fused_infer_attention_score.out(
             query=query,

@@ -282,7 +282,7 @@ def _update_attn_fia_params(update_stream, forward_context, runtime_shape):
             atten_mask=attn_mask,
             input_layout="TND",
             block_size=block_size,
-            actual_seq_lengths=query_start_loc,
+            actual_seq_lengths=actual_seq_lengths_q,
             actual_seq_lengths_kv=seq_lens,
             num_key_value_heads=num_kv_heads,
             num_heads=num_heads,

@@ -384,12 +384,11 @@ class EagleProposer(Proposer):
         attn_metadata.num_actual_tokens = batch_size
         attn_metadata.max_query_len = 1
         attn_metadata.query_start_loc = self.arange_cpu[:batch_size + 1]
-        attn_metadata.query_start_loc_list = attn_metadata.query_start_loc[
-            1:].tolist()
         attn_metadata.num_decodes, attn_metadata.num_prefills, attn_metadata.num_decode_tokens, attn_metadata.num_prefill_tokens = 0, batch_size, 0, batch_size
         attn_metadata.num_actual_tokens_pcp_padded = attn_metadata.num_decode_tokens + attn_metadata.num_prefill_tokens
 
-        attn_metadata.actual_seq_lengths_q = [1 + i for i in range(batch_size)]
+        attn_metadata.actual_seq_lengths_q = attn_metadata.query_start_loc[
+            1:].tolist()
         attn_metadata.seq_lens_list = attn_metadata.seq_lens.tolist()
         attn_metadata.attn_state = AscendAttentionState.ChunkedPrefill
         for now_speculative in range(

@@ -1042,7 +1042,6 @@ class NPUModelRunner(GPUModelRunner):
             attn_mask=self.attn_mask,
             spec_attn_mask=self.spec_attn_mask,
             attn_state=self.attn_state,
-            is_only_prefill=bool(np.all(num_valid_tokens != 1)),
             max_query_len=max_num_scheduled_tokens,
             decode_token_per_req=self.decode_token_per_req,
             prefill_context_parallel_metadata=long_seq_metadata,

@@ -45,7 +45,6 @@ def build_attn_metadata(
    | None = None,
    spec_attn_mask: torch.Tensor | None = None,
    attn_state: Any | None = None,
-   is_only_prefill: bool = False,
    graph_pad_size: int = -1,
    num_input_tokens: int = 0,
    prefill_context_parallel_metadata: AscendPrefillContextParallelMetadata

@@ -78,7 +77,6 @@ def build_attn_metadata(
        attn_mask=attn_mask,
        spec_attn_mask=spec_attn_mask,
        attn_state=attn_state,
-       is_only_prefill=is_only_prefill,
        graph_pad_size=graph_pad_size,
        num_input_tokens=num_input_tokens,
        prefill_context_parallel_metadata=prefill_context_parallel_metadata,

@@ -247,7 +247,9 @@ class XliteWrapper:
         if not with_prefill or self.full_mode:
             batch = attn_metadata.num_prefills + attn_metadata.num_decodes
             seq_lens = attn_metadata.seq_lens[:batch]
-            query_lens = attn_metadata.query_lens[:batch]
+            query_lens = attn_metadata.query_start_loc_cpu[
+                1:] - attn_metadata.query_start_loc_cpu[:-1]
+            query_lens = query_lens[:batch]
             cached_lens = seq_lens - query_lens
 
             xlite_attn_metadata = ModelAttnMeta()