[2/N][Refactor] Refactor V1 attention for better extensibility (#1995)
### What this PR does / why we need it?
Refactor V1 attention for better extensibility (in preparation for the
torchair attention refactor).
**Main changes:**
- Move each kind of forward pass into its own method, e.g.,
`_forward_prefill_no_cache()`, `_forward_prefill_cache_hit()`,
`_forward_decode_only()`, `_forward_v1_style()`.
### Does this PR introduce _any_ user-facing change?
No.
- vLLM version: v0.10.0
- vLLM main:
14a5d903ab
Signed-off-by: shen-shanshan <467638484@qq.com>
This commit is contained in:
@@ -120,7 +120,7 @@ class AscendAttentionState(Enum):
|
|||||||
@dataclass
|
@dataclass
|
||||||
class AscendMetadata:
|
class AscendMetadata:
|
||||||
|
|
||||||
# **************************** Basic Properties ****************************
|
# **************************** Basic Properties ************************** #
|
||||||
attn_mask: Optional[torch.Tensor] = None
|
attn_mask: Optional[torch.Tensor] = None
|
||||||
# Current state of this attention run.
|
# Current state of this attention run.
|
||||||
attn_state: AscendAttentionState = AscendAttentionState.ChunkedPrefill
|
attn_state: AscendAttentionState = AscendAttentionState.ChunkedPrefill
|
||||||
@@ -138,7 +138,7 @@ class AscendMetadata:
|
|||||||
# Maximum query length in the batch (None for decoding).
|
# Maximum query length in the batch (None for decoding).
|
||||||
max_query_len: Optional[int] = None
|
max_query_len: Optional[int] = None
|
||||||
|
|
||||||
# ********************** KV Cache Related Properties ***********************
|
# ********************** KV Cache Related Properties ********************* #
|
||||||
# Block addresses per sequence (Seq id -> list of physical block).
|
# Block addresses per sequence (Seq id -> list of physical block).
|
||||||
# (batch_size, max_blocks_per_seq)
|
# (batch_size, max_blocks_per_seq)
|
||||||
block_tables: torch.Tensor = None
|
block_tables: torch.Tensor = None
|
||||||
@@ -150,6 +150,7 @@ class AscendMetadata:
|
|||||||
# (num_tokens,)
|
# (num_tokens,)
|
||||||
slot_mapping: torch.Tensor = None
|
slot_mapping: torch.Tensor = None
|
||||||
|
|
||||||
|
# *************************** Other Properties *************************** #
|
||||||
enable_dbo_across_dp: bool = False
|
enable_dbo_across_dp: bool = False
|
||||||
is_only_prefill: bool = False
|
is_only_prefill: bool = False
|
||||||
|
|
||||||
@@ -245,6 +246,144 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
|||||||
self.key_cache = None
|
self.key_cache = None
|
||||||
self.value_cache = None
|
self.value_cache = None
|
||||||
|
|
||||||
|
def _forward_prefill_no_cache(
    self,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_metadata: AscendMetadata,
    output: Optional[torch.Tensor] = None,
    num_tokens=0,
) -> torch.Tensor:
    """Run prefill attention directly on ``query``/``key``/``value``.

    This path does not read ``self.key_cache`` / ``self.value_cache``; it
    calls ``torch_npu._npu_flash_attention`` with the tensors passed in.

    Args:
        query: Query tensor handed to the kernel.
        key: Key tensor handed to the kernel.
        value: Value tensor handed to the kernel.
        attn_metadata: Must carry a non-None ``attn_mask`` and ``seq_lens``.
        output: Buffer the kernel writes into (passed as ``out=``). Although
            the signature allows ``None``, a real tensor is required.
        num_tokens: Number of leading rows of ``output`` to return. With the
            default of 0 the returned slice is empty.

    Returns:
        ``output[:num_tokens, :, :]``.
    """
    assert attn_metadata is not None
    assert attn_metadata.attn_mask is not None
    # Validate the output buffer up front: it is consumed below (by
    # aligned_16 on 310P and as the kernel's ``out=`` argument), so checking
    # it only after the kernel call would never catch a missing buffer.
    assert output is not None

    mask = attn_metadata.attn_mask

    # NOTE(review): the nesting under ``is_310p()`` is reconstructed from a
    # view with no indentation - confirm against the real file.
    if is_310p():
        # align q k v output tensors
        query = aligned_16(query)
        key = aligned_16(key)
        value = aligned_16(value)
        output = aligned_16(output)
        # do reformat in case of broadcasted tensors
        mask = mask.repeat(attn_metadata.seq_lens.size(0), 1, 1, 1)
        mask = torch_npu.npu_format_cast(mask.contiguous(),
                                         ACL_FORMAT_FRACTAL_NZ)

    torch_npu._npu_flash_attention(query=query,
                                   key=key,
                                   value=value,
                                   mask=mask,
                                   seq_len=attn_metadata.seq_lens,
                                   scale_value=self.scale,
                                   num_heads=self.num_heads,
                                   num_kv_heads=self.num_kv_heads,
                                   out=output)
    # Keep only the first ``num_tokens`` rows of the output buffer.
    return output[:num_tokens, :, :]
|
||||||
|
|
||||||
|
def _forward_prefill_cache_hit(
    self,
    query: torch.Tensor,
    attn_metadata: AscendMetadata,
    output: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Run prefill attention against the KV cache.

    Calls ``torch_npu._npu_flash_attention_qlens`` with ``self.key_cache``
    and ``self.value_cache``, using ``query_lens`` as ``seq_len`` and
    ``seq_lens`` as ``context_lens``.

    Args:
        query: Query tensor handed to the kernel.
        attn_metadata: Must carry a non-None ``attn_mask``, plus
            ``query_lens``, ``seq_lens`` and ``block_tables``.
        output: Buffer passed to the kernel as ``out=``.

    Returns:
        The same ``output`` object that was passed in.
    """
    assert attn_metadata is not None
    assert attn_metadata.attn_mask is not None

    # One block-table row per entry in query_lens; any extra rows are
    # dropped before the kernel call.
    num_seqs = attn_metadata.query_lens.shape[0]
    seq_block_table = attn_metadata.block_tables[:num_seqs, :]

    torch_npu._npu_flash_attention_qlens(
        query=query,
        key_cache=self.key_cache,
        value_cache=self.value_cache,
        block_table=seq_block_table,
        mask=attn_metadata.attn_mask,
        seq_len=attn_metadata.query_lens,
        context_lens=attn_metadata.seq_lens,
        num_kv_heads=self.num_kv_heads,
        num_heads=self.num_heads,
        scale_value=self.scale,
        out=output,
    )
    return output
|
||||||
|
|
||||||
|
def _forward_decode_only(
    self,
    query: torch.Tensor,
    attn_metadata: AscendMetadata,
    output: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Run decode attention through ``torch_npu._npu_paged_attention``.

    Args:
        query: Query tensor handed to the kernel.
        attn_metadata: Supplies ``block_tables`` and ``seq_lens``. On 310P
            its ``seq_lens`` attribute is overwritten with a copy on
            ``query``'s device.
        output: Buffer passed to the kernel as ``out=``.

    Returns:
        The same ``output`` object that was passed in.
    """
    if is_310p():
        # On 310P seq_lens has to live on the device; the moved tensor is
        # stored back on the metadata object.
        device_seq_lens = attn_metadata.seq_lens.to(device=query.device)
        attn_metadata.seq_lens = device_seq_lens

    torch_npu._npu_paged_attention(
        query=query,
        key_cache=self.key_cache,
        value_cache=self.value_cache,
        num_kv_heads=self.num_kv_heads,
        num_heads=self.num_heads,
        scale_value=self.scale,
        block_table=attn_metadata.block_tables,
        context_lens=attn_metadata.seq_lens,
        out=output,
    )
    return output
|
||||||
|
|
||||||
|
def _forward_v1_style(
    self,
    query: torch.Tensor,
    attn_metadata: AscendMetadata,
    output: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Run attention for the normal V1 (chunked prefill) situation.

    Two paths:
      * ``self.head_size == 192``: falls back to ``vanilla_chunked_prefill``.
      * otherwise: ``torch_npu._npu_paged_attention_splitfuse``.

    Args:
        query: Query tensor handed to the kernel.
        attn_metadata: Supplies ``query_lens``, ``seq_lens`` and
            ``block_tables``; the paged path also requires a non-None
            ``attn_mask``. On 310P the paged path overwrites
            ``attn_metadata.attn_mask`` and ``attn_metadata.seq_lens``.
        output: Buffer the chosen kernel writes into.

    Returns:
        The same ``output`` object that was passed in.
    """
    # Head size 192 (e.g. deepseek) uses the vanilla chunked-prefill
    # implementation because paged_attention_splitfuse may crash there.
    # TODO: the vanilla path will be removed once the kernel supports
    # head_size 192.
    if self.head_size == 192:
        q_lens = attn_metadata.query_lens
        kv_lens = attn_metadata.seq_lens

        # Cumulative lengths with a leading zero, built on query's device.
        q_offsets = torch.tensor([0] + q_lens.tolist(), device=query.device)
        kv_offsets = torch.tensor([0] + kv_lens.tolist(),
                                  device=query.device)
        cu_seqlen_q = torch.cumsum(q_offsets, dim=0)
        cu_seqlen_k = torch.cumsum(kv_offsets, dim=0)

        max_seqlen_q = torch.max(q_lens)
        max_seqlen_k = torch.max(kv_lens)

        vanilla_chunked_prefill(output, query, self.key_cache,
                                self.value_cache,
                                attn_metadata.block_tables, cu_seqlen_q,
                                cu_seqlen_k, max_seqlen_q, max_seqlen_k,
                                self.scale, None, True)
        return output

    # Paged attention path.
    assert attn_metadata is not None
    assert attn_metadata.attn_mask is not None

    # NOTE(review): both assignments are assumed to sit under ``is_310p()``;
    # the view this was reconstructed from has no indentation - confirm.
    if is_310p():
        # Reformat the mask in case it is a broadcasted tensor.
        contiguous_mask = attn_metadata.attn_mask.contiguous()
        attn_metadata.attn_mask = torch_npu.npu_format_cast(
            contiguous_mask, ACL_FORMAT_FRACTAL_NZ)
        attn_metadata.seq_lens = attn_metadata.seq_lens.to(
            device=query.device)

    torch_npu._npu_paged_attention_splitfuse(
        query=query,
        key_cache=self.key_cache,
        value_cache=self.value_cache,
        mask=attn_metadata.attn_mask,
        block_table=attn_metadata.block_tables,
        seq_len=attn_metadata.query_lens,
        context_lens=attn_metadata.seq_lens,
        num_kv_heads=self.num_kv_heads,
        num_heads=self.num_heads,
        scale_value=self.scale,
        out=output,
    )
    return output
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
layer: AttentionLayer,
|
layer: AttentionLayer,
|
||||||
@@ -325,109 +464,18 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
|||||||
|
|
||||||
# V0-Style scheduler situation.
|
# V0-Style scheduler situation.
|
||||||
if attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
|
if attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
|
||||||
assert attn_metadata is not None
|
output = self._forward_prefill_no_cache(
|
||||||
assert attn_metadata.attn_mask is not None
|
query, key, value, attn_metadata, output, num_tokens)
|
||||||
mask = attn_metadata.attn_mask
|
elif attn_metadata.attn_state == \
|
||||||
if is_310p():
|
AscendAttentionState.PrefillCacheHit:
|
||||||
# align q k v output tensors
|
output = self._forward_prefill_cache_hit(
|
||||||
query = aligned_16(query)
|
query, attn_metadata, output)
|
||||||
key = aligned_16(key)
|
|
||||||
value = aligned_16(value)
|
|
||||||
output = aligned_16(output)
|
|
||||||
|
|
||||||
# do reformat in case of broadcasted tensors
|
|
||||||
mask = mask.repeat(attn_metadata.seq_lens.size(0), 1, 1, 1)
|
|
||||||
mask = torch_npu.npu_format_cast(mask.contiguous(),
|
|
||||||
ACL_FORMAT_FRACTAL_NZ)
|
|
||||||
|
|
||||||
torch_npu._npu_flash_attention(query=query,
|
|
||||||
key=key,
|
|
||||||
value=value,
|
|
||||||
mask=mask,
|
|
||||||
seq_len=attn_metadata.seq_lens,
|
|
||||||
scale_value=self.scale,
|
|
||||||
num_heads=self.num_heads,
|
|
||||||
num_kv_heads=self.num_kv_heads,
|
|
||||||
out=output)
|
|
||||||
output = output[:num_tokens, :, :]
|
|
||||||
elif attn_metadata.attn_state == AscendAttentionState.PrefillCacheHit:
|
|
||||||
assert attn_metadata is not None
|
|
||||||
assert attn_metadata.attn_mask is not None
|
|
||||||
compress_mask = attn_metadata.attn_mask
|
|
||||||
batch_size = attn_metadata.query_lens.shape[0]
|
|
||||||
block_table = attn_metadata.block_tables[:batch_size, :]
|
|
||||||
torch_npu._npu_flash_attention_qlens(
|
|
||||||
query=query,
|
|
||||||
key_cache=self.key_cache,
|
|
||||||
value_cache=self.value_cache,
|
|
||||||
block_table=block_table,
|
|
||||||
mask=compress_mask,
|
|
||||||
seq_len=attn_metadata.query_lens,
|
|
||||||
context_lens=attn_metadata.seq_lens,
|
|
||||||
num_kv_heads=self.num_kv_heads,
|
|
||||||
num_heads=self.num_heads,
|
|
||||||
scale_value=self.scale,
|
|
||||||
out=output)
|
|
||||||
elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
|
elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
|
||||||
if is_310p():
|
output = self._forward_decode_only(query, attn_metadata,
|
||||||
# # seq_lens_tensor needs to be transferred to the device for 310P
|
output)
|
||||||
attn_metadata.seq_lens = \
|
|
||||||
attn_metadata.seq_lens.to(device=query.device)
|
|
||||||
torch_npu._npu_paged_attention(
|
|
||||||
query=query,
|
|
||||||
key_cache=self.key_cache,
|
|
||||||
value_cache=self.value_cache,
|
|
||||||
num_kv_heads=self.num_kv_heads,
|
|
||||||
num_heads=self.num_heads,
|
|
||||||
scale_value=self.scale,
|
|
||||||
block_table=attn_metadata.block_tables,
|
|
||||||
context_lens=attn_metadata.seq_lens,
|
|
||||||
out=output)
|
|
||||||
# Normal V1 situation.
|
# Normal V1 situation.
|
||||||
else:
|
else:
|
||||||
# use chunked prefill for head size 192 scenario, like deepseek
|
output = self._forward_v1_style(query, attn_metadata, output)
|
||||||
# paged_attention_splitfuse maybe crash at such scenario
|
|
||||||
# TODO: vanilla path will be removed after the kernel support
|
|
||||||
# head_size 192 scenario
|
|
||||||
if self.head_size == 192:
|
|
||||||
cu_seqlen_q = [0] + attn_metadata.query_lens.tolist()
|
|
||||||
cu_seqlen_k = [0] + attn_metadata.seq_lens.tolist()
|
|
||||||
cu_seqlen_q = torch.tensor(cu_seqlen_q,
|
|
||||||
device=query.device)
|
|
||||||
cu_seqlen_k = torch.tensor(cu_seqlen_k,
|
|
||||||
device=query.device)
|
|
||||||
cu_seqlen_q = torch.cumsum(cu_seqlen_q, dim=0)
|
|
||||||
cu_seqlen_k = torch.cumsum(cu_seqlen_k, dim=0)
|
|
||||||
max_seqlen_q = torch.max(attn_metadata.query_lens)
|
|
||||||
max_seqlen_k = torch.max(attn_metadata.seq_lens)
|
|
||||||
vanilla_chunked_prefill(output, query, self.key_cache,
|
|
||||||
self.value_cache,
|
|
||||||
attn_metadata.block_tables,
|
|
||||||
cu_seqlen_q, cu_seqlen_k,
|
|
||||||
max_seqlen_q, max_seqlen_k,
|
|
||||||
self.scale, None, True)
|
|
||||||
else:
|
|
||||||
# use paged attention
|
|
||||||
assert attn_metadata is not None
|
|
||||||
assert attn_metadata.attn_mask is not None
|
|
||||||
if is_310p():
|
|
||||||
# do reformat in case of broadcasted tensors
|
|
||||||
attn_metadata.attn_mask = \
|
|
||||||
torch_npu.npu_format_cast(attn_metadata.attn_mask.contiguous(), ACL_FORMAT_FRACTAL_NZ)
|
|
||||||
attn_metadata.seq_lens = \
|
|
||||||
attn_metadata.seq_lens.to(device=query.device)
|
|
||||||
torch_npu._npu_paged_attention_splitfuse(
|
|
||||||
query=query,
|
|
||||||
key_cache=self.key_cache,
|
|
||||||
value_cache=self.value_cache,
|
|
||||||
mask=attn_metadata.attn_mask,
|
|
||||||
block_table=attn_metadata.block_tables,
|
|
||||||
seq_len=attn_metadata.query_lens,
|
|
||||||
context_lens=attn_metadata.seq_lens,
|
|
||||||
num_kv_heads=self.num_kv_heads,
|
|
||||||
num_heads=self.num_heads,
|
|
||||||
scale_value=self.scale,
|
|
||||||
out=output)
|
|
||||||
|
|
||||||
# to make in-place change to the output tensor
|
# to make in-place change to the output tensor
|
||||||
if hasattr(layer, 'quant_method') and use_kv_cache_int8:
|
if hasattr(layer, 'quant_method') and use_kv_cache_int8:
|
||||||
|
|||||||
Reference in New Issue
Block a user