### What this PR does / why we need it?
**Scope of Changes**:
| File Path |
| :--- |
| `vllm_ascend/attention/attention_mask.py` |
| `vllm_ascend/attention/attention_v1.py` |
| `vllm_ascend/attention/context_parallel/attention_cp.py` |
| `vllm_ascend/attention/context_parallel/common_cp.py` |
| `vllm_ascend/attention/context_parallel/mla_cp.py` |
| `vllm_ascend/attention/utils.py` |
| `vllm_ascend/batch_invariant.py` |
| `vllm_ascend/device/device_op.py` |
| `vllm_ascend/device_allocator/camem.py` |
| `vllm_ascend/envs.py` |
- vLLM version: v0.13.0
- vLLM main:
2c24bc6996
---------
Signed-off-by: MrZ20 <2609716663@qq.com>
This commit is contained in:
@@ -17,7 +17,7 @@
|
||||
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import ClassVar, List, Optional, Tuple, Type
|
||||
from typing import ClassVar
|
||||
|
||||
import torch
|
||||
import torch_npu
|
||||
@@ -29,32 +29,49 @@ from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec, CrossAttentionSpec
|
||||
|
||||
from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
|
||||
from vllm_ascend.attention.context_parallel.common_cp import (
|
||||
AscendMetadataForDecode, AscendMetadataForPrefill)
|
||||
from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
|
||||
enable_cp, split_decodes_and_prefills,
|
||||
using_paged_attention)
|
||||
from vllm_ascend.attention.context_parallel.common_cp import AscendMetadataForDecode, AscendMetadataForPrefill
|
||||
from vllm_ascend.attention.utils import (
|
||||
AscendCommonAttentionMetadata,
|
||||
enable_cp,
|
||||
split_decodes_and_prefills,
|
||||
using_paged_attention,
|
||||
)
|
||||
from vllm_ascend.compilation.acl_graph import (
|
||||
get_draft_graph_params, get_graph_params,
|
||||
update_draft_graph_params_workspaces, update_graph_params_workspaces)
|
||||
get_draft_graph_params,
|
||||
get_graph_params,
|
||||
update_draft_graph_params_workspaces,
|
||||
update_graph_params_workspaces,
|
||||
)
|
||||
from vllm_ascend.device.device_op import DeviceOperator
|
||||
from vllm_ascend.ops.flashcomm2_oshard_manager import flashcomm2_oshard_manager
|
||||
from vllm_ascend.utils import vllm_version_is, weak_ref_tensors
|
||||
|
||||
# isort: off
|
||||
if vllm_version_is('0.13.0'):
|
||||
from vllm.v1.attention.backends.utils import (AttentionCGSupport,
|
||||
AttentionMetadataBuilder)
|
||||
if vllm_version_is("0.13.0"):
|
||||
from vllm.v1.attention.backends.utils import AttentionCGSupport, AttentionMetadataBuilder
|
||||
from vllm.attention.backends.abstract import ( # type: ignore
|
||||
AttentionBackend, AttentionImpl, AttentionLayer, AttentionType)
|
||||
AttentionBackend,
|
||||
AttentionImpl,
|
||||
AttentionLayer,
|
||||
AttentionType,
|
||||
)
|
||||
from vllm.attention.backends.registry import ( # type: ignore
|
||||
AttentionBackendEnum, register_backend)
|
||||
AttentionBackendEnum,
|
||||
register_backend,
|
||||
)
|
||||
else:
|
||||
from vllm.v1.attention.backend import ( # type: ignore
|
||||
AttentionBackend, AttentionCGSupport, AttentionImpl, AttentionLayer,
|
||||
AttentionType, AttentionMetadataBuilder)
|
||||
AttentionBackend,
|
||||
AttentionCGSupport,
|
||||
AttentionImpl,
|
||||
AttentionLayer,
|
||||
AttentionType,
|
||||
AttentionMetadataBuilder,
|
||||
)
|
||||
from vllm.v1.attention.backends.registry import ( # type: ignore
|
||||
AttentionBackendEnum, register_backend)
|
||||
AttentionBackendEnum,
|
||||
register_backend,
|
||||
)
|
||||
# isort: on
|
||||
|
||||
# default max value of sliding window size
|
||||
@@ -73,18 +90,18 @@ class AscendAttentionBackend(AttentionBackend):
|
||||
return "CUSTOM" if not envs_vllm.VLLM_USE_V2_MODEL_RUNNER else "FLASH_ATTN"
|
||||
|
||||
@staticmethod
|
||||
def get_impl_cls() -> Type["AscendAttentionBackendImpl"]:
|
||||
def get_impl_cls() -> type["AscendAttentionBackendImpl"]:
|
||||
if enable_cp():
|
||||
from vllm_ascend.attention.context_parallel.attention_cp import \
|
||||
AscendAttentionCPImpl
|
||||
from vllm_ascend.attention.context_parallel.attention_cp import AscendAttentionCPImpl
|
||||
|
||||
return AscendAttentionCPImpl
|
||||
return AscendAttentionBackendImpl
|
||||
|
||||
@staticmethod
|
||||
def get_builder_cls() -> type["AscendAttentionMetadataBuilder"]:
|
||||
if enable_cp():
|
||||
from vllm_ascend.attention.context_parallel.attention_cp import \
|
||||
AscendAttentionCPMetadataBuilder
|
||||
from vllm_ascend.attention.context_parallel.attention_cp import AscendAttentionCPMetadataBuilder
|
||||
|
||||
return AscendAttentionCPMetadataBuilder
|
||||
return AscendAttentionMetadataBuilder
|
||||
|
||||
@@ -94,13 +111,13 @@ class AscendAttentionBackend(AttentionBackend):
|
||||
block_size: int,
|
||||
num_kv_heads: int,
|
||||
head_size: int,
|
||||
) -> Tuple[int, ...]:
|
||||
) -> tuple[int, ...]:
|
||||
return (2, num_blocks, block_size, num_kv_heads, head_size)
|
||||
|
||||
@staticmethod
|
||||
def swap_blocks(
|
||||
src_kv_cache: List[torch.Tensor],
|
||||
dst_kv_cache: List[torch.Tensor],
|
||||
src_kv_cache: list[torch.Tensor],
|
||||
dst_kv_cache: list[torch.Tensor],
|
||||
src_to_dst: torch.Tensor,
|
||||
) -> None:
|
||||
src_key_cache, src_value_cache = src_kv_cache[0], src_kv_cache[1]
|
||||
@@ -108,14 +125,12 @@ class AscendAttentionBackend(AttentionBackend):
|
||||
src_indices = src_to_dst[:, 0]
|
||||
dst_indices = src_to_dst[:, 1]
|
||||
|
||||
dst_key_cache[dst_indices] = src_key_cache[src_indices].to(
|
||||
dst_key_cache.device)
|
||||
dst_value_cache[dst_indices] = src_value_cache[src_indices].to(
|
||||
dst_key_cache.device)
|
||||
dst_key_cache[dst_indices] = src_key_cache[src_indices].to(dst_key_cache.device)
|
||||
dst_value_cache[dst_indices] = src_value_cache[src_indices].to(dst_key_cache.device)
|
||||
|
||||
@staticmethod
|
||||
def copy_blocks(
|
||||
kv_caches: List[torch.Tensor],
|
||||
kv_caches: list[torch.Tensor],
|
||||
src_to_dists: torch.Tensor,
|
||||
) -> None:
|
||||
src_indices = src_to_dists[:, 0]
|
||||
@@ -148,8 +163,9 @@ class AscendMetadata:
|
||||
Contains attention masks, token counts, sequence lengths and KV cache
|
||||
related properties for attention computation.
|
||||
"""
|
||||
|
||||
# **************************** Basic Properties ************************** #
|
||||
attn_mask: Optional[torch.Tensor] = None
|
||||
attn_mask: torch.Tensor | None = None
|
||||
# Current state of this attention run.
|
||||
attn_state: AscendAttentionState = AscendAttentionState.ChunkedPrefill
|
||||
|
||||
@@ -168,12 +184,12 @@ class AscendMetadata:
|
||||
# should simplified these parameters once attention schema in vLLM-Ascend
|
||||
# is unified.
|
||||
seq_lens: torch.Tensor = None
|
||||
seq_lens_list: List[int] = None # type: ignore
|
||||
actual_seq_lengths_q: List[int] = None # type: ignore
|
||||
seq_lens_list: list[int] = None # type: ignore
|
||||
actual_seq_lengths_q: list[int] = None # type: ignore
|
||||
|
||||
query_start_loc: torch.Tensor = None
|
||||
# Maximum query length in the batch (None for decoding).
|
||||
max_query_len: Optional[int] = None
|
||||
max_query_len: int | None = None
|
||||
|
||||
# ********************** KV Cache Related Properties ********************* #
|
||||
# Block addresses per sequence (Seq id -> list of physical block).
|
||||
@@ -187,9 +203,9 @@ class AscendMetadata:
|
||||
# (num_tokens,)
|
||||
slot_mapping: torch.Tensor = None
|
||||
# pcp
|
||||
prefill: Optional[AscendMetadataForPrefill] = None
|
||||
prefill: AscendMetadataForPrefill | None = None
|
||||
# dcp
|
||||
decode_meta: Optional[AscendMetadataForDecode] = None
|
||||
decode_meta: AscendMetadataForDecode | None = None
|
||||
|
||||
causal: bool = True
|
||||
# runner_type in model_config.
|
||||
@@ -198,7 +214,7 @@ class AscendMetadata:
|
||||
reshape_cache_event: torch.npu.Event = None
|
||||
|
||||
# sliding window attention mask
|
||||
swa_mask: Optional[torch.Tensor] = None
|
||||
swa_mask: torch.Tensor | None = None
|
||||
|
||||
|
||||
class AscendAttentionMetadataBuilder(AttentionMetadataBuilder[AscendMetadata]):
|
||||
@@ -208,6 +224,7 @@ class AscendAttentionMetadataBuilder(AttentionMetadataBuilder[AscendMetadata]):
|
||||
Handles attention mask generation and metadata preparation for
|
||||
Ascend FlashAttention backend.
|
||||
"""
|
||||
|
||||
# Does this backend/builder reorder the batch?
|
||||
# If not, set this to None. Otherwise set it to the query
|
||||
# length that will be pulled into the front of the batch.
|
||||
@@ -226,17 +243,19 @@ class AscendAttentionMetadataBuilder(AttentionMetadataBuilder[AscendMetadata]):
|
||||
self.compilation_config = vllm_config.compilation_config
|
||||
self.device = device
|
||||
self.max_num_blocks_per_req = cdiv(
|
||||
self.model_config.max_model_len,
|
||||
AscendAttentionBackend.get_supported_block_size()[0])
|
||||
self.model_config.max_model_len, AscendAttentionBackend.get_supported_block_size()[0]
|
||||
)
|
||||
|
||||
self.speculative_config = vllm_config.speculative_config
|
||||
self.decode_threshold = 1
|
||||
if self.speculative_config:
|
||||
spec_token_num = self.speculative_config.num_speculative_tokens
|
||||
self.decode_threshold += spec_token_num
|
||||
assert self.decode_threshold <= 16, f"decode_threshold exceeded \
|
||||
assert self.decode_threshold <= 16, (
|
||||
f"decode_threshold exceeded \
|
||||
npu_fused_infer_attention_score TND layout's limit of 16, \
|
||||
got {self.decode_threshold}"
|
||||
)
|
||||
|
||||
AscendAttentionMetadataBuilder.reorder_batch_threshold = self.decode_threshold
|
||||
|
||||
@@ -254,8 +273,7 @@ class AscendAttentionMetadataBuilder(AttentionMetadataBuilder[AscendMetadata]):
|
||||
# @override omitted only because of mypy limitation due to type variable.
|
||||
return AttentionCGSupport.ALWAYS
|
||||
|
||||
def reorder_batch(self, input_batch,
|
||||
scheduler_output: "SchedulerOutput") -> bool:
|
||||
def reorder_batch(self, input_batch, scheduler_output: "SchedulerOutput") -> bool:
|
||||
return False
|
||||
|
||||
def build(
|
||||
@@ -266,12 +284,11 @@ class AscendAttentionMetadataBuilder(AttentionMetadataBuilder[AscendMetadata]):
|
||||
) -> AscendMetadata:
|
||||
num_reqs = common_attn_metadata.num_reqs
|
||||
num_actual_tokens = common_attn_metadata.num_actual_tokens
|
||||
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu[:
|
||||
num_reqs
|
||||
+ 1]
|
||||
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu[: num_reqs + 1]
|
||||
|
||||
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \
|
||||
split_decodes_and_prefills(common_attn_metadata, decode_threshold=self.decode_threshold)
|
||||
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = split_decodes_and_prefills(
|
||||
common_attn_metadata, decode_threshold=self.decode_threshold
|
||||
)
|
||||
|
||||
block_table = common_attn_metadata.block_table_tensor
|
||||
seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs]
|
||||
@@ -283,19 +300,17 @@ class AscendAttentionMetadataBuilder(AttentionMetadataBuilder[AscendMetadata]):
|
||||
attn_state = common_attn_metadata.attn_state
|
||||
|
||||
# Get attn_mask and swa_mask from singleton AttentionMaskBuilder
|
||||
attn_mask = self.attn_mask_builder.get_attention_mask(
|
||||
self.model_config)
|
||||
attn_mask = self.attn_mask_builder.get_attention_mask(self.model_config)
|
||||
|
||||
swa_mask = None
|
||||
is_swa = hasattr(self.model_config.hf_text_config, 'sliding_window')
|
||||
is_swa = hasattr(self.model_config.hf_text_config, "sliding_window")
|
||||
if self.model_config is not None and is_swa:
|
||||
swa_mask = self.attn_mask_builder.get_swa_mask(
|
||||
self.model_config.dtype,
|
||||
self.model_config.hf_text_config.sliding_window)
|
||||
self.model_config.dtype, self.model_config.hf_text_config.sliding_window
|
||||
)
|
||||
|
||||
# TODO: Yet another unnecessary H2D while we already have a query_start_loc on device
|
||||
query_start_loc = query_start_loc_cpu.pin_memory().to(
|
||||
self.device, non_blocking=True)
|
||||
query_start_loc = query_start_loc_cpu.pin_memory().to(self.device, non_blocking=True)
|
||||
|
||||
attn_metadata = AscendMetadata(
|
||||
num_actual_tokens=num_actual_tokens,
|
||||
@@ -313,7 +328,8 @@ class AscendAttentionMetadataBuilder(AttentionMetadataBuilder[AscendMetadata]):
|
||||
num_prefills=num_prefills,
|
||||
num_decodes=num_decodes,
|
||||
causal=common_attn_metadata.causal,
|
||||
model_runner_type=self.model_config.runner_type)
|
||||
model_runner_type=self.model_config.runner_type,
|
||||
)
|
||||
return attn_metadata
|
||||
|
||||
def build_for_graph_capture(
|
||||
@@ -321,9 +337,7 @@ class AscendAttentionMetadataBuilder(AttentionMetadataBuilder[AscendMetadata]):
|
||||
common_attn_metadata: AscendCommonAttentionMetadata,
|
||||
attn_state: AscendAttentionState = AscendAttentionState.DecodeOnly,
|
||||
):
|
||||
|
||||
if attn_state in (AscendAttentionState.DecodeOnly,
|
||||
AscendAttentionState.ChunkedPrefill):
|
||||
if attn_state in (AscendAttentionState.DecodeOnly, AscendAttentionState.ChunkedPrefill):
|
||||
attn_metadata = self.build(
|
||||
common_prefix_len=0,
|
||||
common_attn_metadata=common_attn_metadata,
|
||||
@@ -338,19 +352,18 @@ class AscendAttentionMetadataBuilder(AttentionMetadataBuilder[AscendMetadata]):
|
||||
|
||||
|
||||
class AscendAttentionBackendImpl(AttentionImpl):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
scale: float,
|
||||
num_kv_heads: int,
|
||||
alibi_slopes: Optional[List[float]],
|
||||
sliding_window: Optional[int],
|
||||
alibi_slopes: list[float] | None,
|
||||
sliding_window: int | None,
|
||||
kv_cache_dtype: str,
|
||||
logits_soft_cap: Optional[float],
|
||||
logits_soft_cap: float | None,
|
||||
attn_type: str,
|
||||
kv_sharing_target_layer_name: Optional[str],
|
||||
kv_sharing_target_layer_name: str | None,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
self.vllm_config = get_current_vllm_config()
|
||||
@@ -362,9 +375,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
||||
self.kv_cache_dtype = kv_cache_dtype
|
||||
self.sliding_window = sliding_window
|
||||
if alibi_slopes is not None:
|
||||
alibi_slopes = torch.tensor(alibi_slopes,
|
||||
dtype=torch.float32,
|
||||
device="npu")
|
||||
alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32, device="npu")
|
||||
self.alibi_slopes = alibi_slopes
|
||||
self.attn_type = attn_type
|
||||
|
||||
@@ -372,18 +383,24 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
||||
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
||||
self.key_cache = None
|
||||
self.value_cache = None
|
||||
self.is_kv_producer = self.vllm_config.kv_transfer_config is not None and self.vllm_config.kv_transfer_config.is_kv_producer
|
||||
self.is_kv_producer = (
|
||||
self.vllm_config.kv_transfer_config is not None and self.vllm_config.kv_transfer_config.is_kv_producer
|
||||
)
|
||||
|
||||
def process_weights_after_loading(self, act_dtype: torch.dtype):
|
||||
super().process_weights_after_loading(act_dtype)
|
||||
if flashcomm2_oshard_manager.flashcomm2_oshard_enable():
|
||||
flashcomm2_oshard_manager.post_process_after_loading()
|
||||
|
||||
def full_graph_fia(self, query: torch.Tensor, key: torch.Tensor,
|
||||
value: torch.Tensor, attn_metadata: AscendMetadata,
|
||||
output: torch.Tensor) -> torch.Tensor:
|
||||
key, value, block_size, block_table, actual_seq_lengths_kv \
|
||||
= self._get_fia_params(key, value, attn_metadata)
|
||||
def full_graph_fia(
|
||||
self,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attn_metadata: AscendMetadata,
|
||||
output: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
key, value, block_size, block_table, actual_seq_lengths_kv = self._get_fia_params(key, value, attn_metadata)
|
||||
|
||||
num_tokens = attn_metadata.actual_seq_lengths_q[-1]
|
||||
forward_context = get_forward_context()
|
||||
@@ -427,12 +444,22 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
||||
event.reset(stream)
|
||||
graph_params.events[num_tokens].append(event)
|
||||
graph_params.attn_params[num_tokens].append(
|
||||
(weak_ref_tensors(query), weak_ref_tensors(key),
|
||||
weak_ref_tensors(value), weak_ref_tensors(block_table),
|
||||
weak_ref_tensors(attn_metadata.attn_mask), block_size,
|
||||
actual_seq_lengths_kv, actual_seq_lengths_q, self.num_kv_heads,
|
||||
self.num_heads, self.scale, weak_ref_tensors(output),
|
||||
weak_ref_tensors(softmax_lse)))
|
||||
(
|
||||
weak_ref_tensors(query),
|
||||
weak_ref_tensors(key),
|
||||
weak_ref_tensors(value),
|
||||
weak_ref_tensors(block_table),
|
||||
weak_ref_tensors(attn_metadata.attn_mask),
|
||||
block_size,
|
||||
actual_seq_lengths_kv,
|
||||
actual_seq_lengths_q,
|
||||
self.num_kv_heads,
|
||||
self.num_heads,
|
||||
self.scale,
|
||||
weak_ref_tensors(output),
|
||||
weak_ref_tensors(softmax_lse),
|
||||
)
|
||||
)
|
||||
|
||||
torch.npu.graph_task_group_begin(stream)
|
||||
torch_npu.npu_fused_infer_attention_score.out(
|
||||
@@ -463,7 +490,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
||||
self,
|
||||
query: torch.Tensor,
|
||||
attn_metadata: AscendMetadata,
|
||||
output: Optional[torch.Tensor] = None,
|
||||
output: torch.Tensor | None = None,
|
||||
):
|
||||
graph_params = get_graph_params()
|
||||
forward_context: ForwardContext = get_forward_context()
|
||||
@@ -481,7 +508,8 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
||||
scale_value=self.scale,
|
||||
block_table=attn_metadata.block_tables,
|
||||
context_lens=attn_metadata.seq_lens,
|
||||
out=output)
|
||||
out=output,
|
||||
)
|
||||
update_graph_params_workspaces(num_tokens, workspace)
|
||||
|
||||
# Handle graph capturing mode
|
||||
@@ -491,17 +519,19 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
||||
event.wait(stream)
|
||||
event.reset(stream)
|
||||
graph_params.events[num_tokens].append(event)
|
||||
graph_params.attn_params[num_tokens].append((
|
||||
weak_ref_tensors(query),
|
||||
weak_ref_tensors(self.key_cache),
|
||||
weak_ref_tensors(self.value_cache),
|
||||
self.num_kv_heads,
|
||||
self.num_heads,
|
||||
self.scale,
|
||||
attn_metadata.block_tables,
|
||||
attn_metadata.seq_lens,
|
||||
weak_ref_tensors(output),
|
||||
))
|
||||
graph_params.attn_params[num_tokens].append(
|
||||
(
|
||||
weak_ref_tensors(query),
|
||||
weak_ref_tensors(self.key_cache),
|
||||
weak_ref_tensors(self.value_cache),
|
||||
self.num_kv_heads,
|
||||
self.num_heads,
|
||||
self.scale,
|
||||
attn_metadata.block_tables,
|
||||
attn_metadata.seq_lens,
|
||||
weak_ref_tensors(output),
|
||||
)
|
||||
)
|
||||
|
||||
torch.npu.graph_task_group_begin(stream)
|
||||
torch_npu._npu_paged_attention(
|
||||
@@ -514,53 +544,54 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
||||
block_table=attn_metadata.block_tables,
|
||||
context_lens=attn_metadata.seq_lens,
|
||||
out=output,
|
||||
workspace=workspace)
|
||||
workspace=workspace,
|
||||
)
|
||||
handle = torch.npu.graph_task_group_end(stream)
|
||||
graph_params.handles[num_tokens].append(handle)
|
||||
return output
|
||||
|
||||
def _get_fia_params(self, key: torch.Tensor, value: torch.Tensor,
|
||||
attn_metadata: AscendMetadata):
|
||||
|
||||
def _get_fia_params(self, key: torch.Tensor, value: torch.Tensor, attn_metadata: AscendMetadata):
|
||||
if attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
|
||||
block_size = 128
|
||||
block_table = None
|
||||
actual_seq_lengths_kv = attn_metadata.actual_seq_lengths_q
|
||||
if self.attn_type == AttentionType.ENCODER_DECODER:
|
||||
actual_seq_lengths_kv = torch.cumsum(attn_metadata.seq_lens,
|
||||
dim=0).tolist()
|
||||
elif attn_metadata.attn_state == \
|
||||
AscendAttentionState.PrefillCacheHit:
|
||||
actual_seq_lengths_kv = torch.cumsum(attn_metadata.seq_lens, dim=0).tolist()
|
||||
elif attn_metadata.attn_state == AscendAttentionState.PrefillCacheHit:
|
||||
batch_size = attn_metadata.seq_lens.shape[0]
|
||||
block_table = attn_metadata.block_tables[:batch_size, :]
|
||||
num_block, block_size, _, _ = self.key_cache.shape # type: ignore
|
||||
key = self.key_cache.view( # type: ignore
|
||||
num_block, block_size, -1)
|
||||
num_block, block_size, -1
|
||||
)
|
||||
value = self.value_cache.view( # type: ignore
|
||||
num_block, block_size, -1)
|
||||
num_block, block_size, -1
|
||||
)
|
||||
actual_seq_lengths_kv = attn_metadata.seq_lens_list
|
||||
elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
|
||||
num_block, block_size, _, _ = self.key_cache.shape # type: ignore
|
||||
key = self.key_cache.view( # type: ignore
|
||||
num_block, block_size, -1)
|
||||
num_block, block_size, -1
|
||||
)
|
||||
value = self.value_cache.view( # type: ignore
|
||||
num_block, block_size, -1)
|
||||
num_block, block_size, -1
|
||||
)
|
||||
block_table = attn_metadata.block_tables
|
||||
actual_seq_lengths_kv = attn_metadata.seq_lens_list
|
||||
# chunked prefill.
|
||||
else:
|
||||
num_block, block_size, _, _ = self.key_cache.shape # type: ignore
|
||||
key = self.key_cache.view( # type: ignore
|
||||
num_block, block_size, -1)
|
||||
num_block, block_size, -1
|
||||
)
|
||||
value = self.value_cache.view( # type: ignore
|
||||
num_block, block_size, -1)
|
||||
num_block, block_size, -1
|
||||
)
|
||||
block_table = attn_metadata.block_tables
|
||||
actual_seq_lengths_kv = attn_metadata.seq_lens_list
|
||||
return key, value, block_size, block_table, actual_seq_lengths_kv
|
||||
|
||||
def _forward_fia_slidingwindow(self, query: torch.Tensor,
|
||||
attn_metadata: AscendMetadata,
|
||||
output: torch.Tensor):
|
||||
def _forward_fia_slidingwindow(self, query: torch.Tensor, attn_metadata: AscendMetadata, output: torch.Tensor):
|
||||
batch_size = attn_metadata.seq_lens.shape[0]
|
||||
block_size = 128
|
||||
query = query.view(batch_size, 1, self.num_heads * self.head_size)
|
||||
@@ -583,34 +614,41 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
||||
scale=self.scale,
|
||||
block_table=attn_metadata.block_tables,
|
||||
actual_seq_lengths=[1] * len(attn_metadata.seq_lens),
|
||||
actual_seq_lengths_kv=attn_metadata.seq_lens)
|
||||
actual_seq_lengths_kv=attn_metadata.seq_lens,
|
||||
)
|
||||
|
||||
output = output.view(batch_size, self.num_heads, self.head_size)
|
||||
return output
|
||||
|
||||
def forward_fused_infer_attention(self, query: torch.Tensor,
|
||||
key: torch.Tensor, value: torch.Tensor,
|
||||
attn_metadata: AscendMetadata,
|
||||
output: torch.Tensor):
|
||||
def forward_fused_infer_attention(
|
||||
self,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attn_metadata: AscendMetadata,
|
||||
output: torch.Tensor,
|
||||
):
|
||||
forward_context: ForwardContext = get_forward_context()
|
||||
# we inherit ForwardContext in model runner v2, when enable model
|
||||
# runner v2, there is not capturing attribute in forward_context,
|
||||
# just use getattr to avoid attribute error.
|
||||
if getattr(forward_context, "capturing", False):
|
||||
attn_output, num_tokens = self.full_graph_fia(
|
||||
query, key, value, attn_metadata, output)
|
||||
attn_output, num_tokens = self.full_graph_fia(query, key, value, attn_metadata, output)
|
||||
output[:num_tokens] = attn_output[:num_tokens]
|
||||
return output
|
||||
if (attn_metadata.attn_state == AscendAttentionState.DecodeOnly
|
||||
and self.sliding_window is not None
|
||||
and attn_metadata.seq_lens.shape[0] == query.size(0)):
|
||||
return self._forward_fia_slidingwindow(query, attn_metadata,
|
||||
output)
|
||||
key, value, block_size, block_table, actual_seq_lengths_kv \
|
||||
= self._get_fia_params(key, value, attn_metadata)
|
||||
if (
|
||||
attn_metadata.attn_state == AscendAttentionState.DecodeOnly
|
||||
and self.sliding_window is not None
|
||||
and attn_metadata.seq_lens.shape[0] == query.size(0)
|
||||
):
|
||||
return self._forward_fia_slidingwindow(query, attn_metadata, output)
|
||||
key, value, block_size, block_table, actual_seq_lengths_kv = self._get_fia_params(key, value, attn_metadata)
|
||||
num_tokens = attn_metadata.actual_seq_lengths_q[-1]
|
||||
query = query[:num_tokens]
|
||||
if attn_metadata.attn_state == AscendAttentionState.PrefillNoCache and self.attn_type != AttentionType.ENCODER_DECODER:
|
||||
if (
|
||||
attn_metadata.attn_state == AscendAttentionState.PrefillNoCache
|
||||
and self.attn_type != AttentionType.ENCODER_DECODER
|
||||
):
|
||||
key = key[:num_tokens]
|
||||
value = value[:num_tokens]
|
||||
# Get workspace from cache or calculate it if not present.
|
||||
@@ -630,8 +668,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
||||
sparse_mode=3,
|
||||
)
|
||||
|
||||
attn_output = attn_output.view(num_tokens, self.num_heads,
|
||||
self.head_size)
|
||||
attn_output = attn_output.view(num_tokens, self.num_heads, self.head_size)
|
||||
output[:num_tokens] = attn_output[:num_tokens]
|
||||
return output
|
||||
|
||||
@@ -639,26 +676,32 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
||||
self,
|
||||
query: torch.Tensor,
|
||||
attn_metadata: AscendMetadata,
|
||||
output: Optional[torch.Tensor] = None,
|
||||
output: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
forward_context: ForwardContext = get_forward_context()
|
||||
if forward_context.capturing:
|
||||
return self.full_graph_pa(query, attn_metadata, output)
|
||||
torch_npu._npu_paged_attention(query=query,
|
||||
key_cache=self.key_cache,
|
||||
value_cache=self.value_cache,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
num_heads=self.num_heads,
|
||||
scale_value=self.scale,
|
||||
block_table=attn_metadata.block_tables,
|
||||
context_lens=attn_metadata.seq_lens,
|
||||
out=output)
|
||||
torch_npu._npu_paged_attention(
|
||||
query=query,
|
||||
key_cache=self.key_cache,
|
||||
value_cache=self.value_cache,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
num_heads=self.num_heads,
|
||||
scale_value=self.scale,
|
||||
block_table=attn_metadata.block_tables,
|
||||
context_lens=attn_metadata.seq_lens,
|
||||
out=output,
|
||||
)
|
||||
return output
|
||||
|
||||
def _forward_encoder_attention(self, query: torch.Tensor,
|
||||
key: torch.Tensor, value: torch.Tensor,
|
||||
attn_metadata: AscendMetadata,
|
||||
_: torch.Tensor) -> torch.Tensor:
|
||||
def _forward_encoder_attention(
|
||||
self,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attn_metadata: AscendMetadata,
|
||||
_: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
assert attn_metadata is not None
|
||||
|
||||
if attn_metadata.causal:
|
||||
@@ -692,26 +735,23 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
||||
self,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
kv_cache: Tuple[torch.Tensor],
|
||||
kv_cache: tuple[torch.Tensor],
|
||||
attn_metadata: AscendMetadata,
|
||||
):
|
||||
|
||||
if len(kv_cache) > 1:
|
||||
if self.is_kv_producer:
|
||||
attn_metadata.reshape_cache_event = torch.npu.Event()
|
||||
if self.key_cache is None:
|
||||
self.key_cache, self.value_cache = kv_cache[0], kv_cache[1]
|
||||
slots = attn_metadata.slot_mapping
|
||||
encoder_decoder = (self.attn_type == AttentionType.ENCODER_DECODER)
|
||||
encoder_decoder = self.attn_type == AttentionType.ENCODER_DECODER
|
||||
DeviceOperator.reshape_and_cache(
|
||||
key=key[:attn_metadata.num_actual_tokens]
|
||||
if not encoder_decoder else key,
|
||||
value=value[:attn_metadata.num_actual_tokens]
|
||||
if not encoder_decoder else value,
|
||||
key=key[: attn_metadata.num_actual_tokens] if not encoder_decoder else key,
|
||||
value=value[: attn_metadata.num_actual_tokens] if not encoder_decoder else value,
|
||||
key_cache=self.key_cache,
|
||||
value_cache=self.value_cache,
|
||||
slot_mapping=slots[:attn_metadata.num_actual_tokens]
|
||||
if not encoder_decoder else slots)
|
||||
slot_mapping=slots[: attn_metadata.num_actual_tokens] if not encoder_decoder else slots,
|
||||
)
|
||||
if self.is_kv_producer:
|
||||
attn_metadata.reshape_cache_event.record()
|
||||
return key, value
|
||||
@@ -721,18 +761,19 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
kv_cache: Tuple[torch.Tensor],
|
||||
kv_cache: tuple[torch.Tensor],
|
||||
attn_metadata: AscendMetadata,
|
||||
output: torch.Tensor,
|
||||
):
|
||||
num_tokens = query.shape[0]
|
||||
if (attn_metadata.attn_state == AscendAttentionState.DecodeOnly
|
||||
and using_paged_attention(num_tokens, self.vllm_config)
|
||||
and self.sliding_window is None):
|
||||
if (
|
||||
attn_metadata.attn_state == AscendAttentionState.DecodeOnly
|
||||
and using_paged_attention(num_tokens, self.vllm_config)
|
||||
and self.sliding_window is None
|
||||
):
|
||||
output = self.forward_paged_attention(query, attn_metadata, output)
|
||||
else:
|
||||
output = self.forward_fused_infer_attention(
|
||||
query, key, value, attn_metadata, output)
|
||||
output = self.forward_fused_infer_attention(query, key, value, attn_metadata, output)
|
||||
|
||||
return output
|
||||
|
||||
@@ -742,11 +783,11 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
kv_cache: Tuple[torch.Tensor],
|
||||
kv_cache: tuple[torch.Tensor],
|
||||
attn_metadata: AscendMetadata,
|
||||
output: Optional[torch.Tensor] = None,
|
||||
output_scale: Optional[torch.Tensor] = None,
|
||||
output_block_scale: Optional[torch.Tensor] = None,
|
||||
output: torch.Tensor | None = None,
|
||||
output_scale: torch.Tensor | None = None,
|
||||
output_block_scale: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
"""Forward pass with Ascend attention.
|
||||
Args:
|
||||
@@ -762,23 +803,18 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
||||
assert output is not None, "Output tensor must be provided."
|
||||
|
||||
if output_scale is not None or output_block_scale is not None:
|
||||
raise NotImplementedError(
|
||||
"fused output quantization is not yet supported"
|
||||
" for AscendAttentionBackendImpl")
|
||||
raise NotImplementedError("fused output quantization is not yet supported for AscendAttentionBackendImpl")
|
||||
|
||||
assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0
|
||||
num_tokens = query.shape[0]
|
||||
if attn_metadata is None:
|
||||
return output.fill_(0)
|
||||
if key is not None and value is not None:
|
||||
key, value = self.reshape_and_cache(key, value, kv_cache,
|
||||
attn_metadata)
|
||||
key, value = self.reshape_and_cache(key, value, kv_cache, attn_metadata)
|
||||
# pooling model branch
|
||||
if attn_metadata.model_runner_type == "pooling":
|
||||
attn_output = self._forward_encoder_attention(
|
||||
query, key, value, attn_metadata, output)
|
||||
attn_output = self._forward_encoder_attention(query, key, value, attn_metadata, output)
|
||||
output[:num_tokens] = attn_output[:num_tokens]
|
||||
return output
|
||||
output = self.forward_impl(query, key, value, kv_cache, attn_metadata,
|
||||
output)
|
||||
output = self.forward_impl(query, key, value, kv_cache, attn_metadata, output)
|
||||
return output
|
||||
|
||||
Reference in New Issue
Block a user