forked from EngineX-MetaX/enginex-c_series-vllm
1499 lines
65 KiB
Python
1499 lines
65 KiB
Python
|
|
# SPDX-License-Identifier: Apache-2.0
|
|||
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|||
|
|
"""Attention layer with Dual chunk flash attention and sparse attention.
|
|||
|
|
"""
|
|||
|
|
import math
|
|||
|
|
from dataclasses import dataclass
|
|||
|
|
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type
|
|||
|
|
|
|||
|
|
import torch
|
|||
|
|
import torch.distributed
|
|||
|
|
import torch.nn.functional as F
|
|||
|
|
|
|||
|
|
from vllm import _custom_ops as ops
|
|||
|
|
from vllm.attention.backends.abstract import AttentionLayer, AttentionType
|
|||
|
|
from vllm.attention.backends.flash_attn import (FlashAttentionBackend,
|
|||
|
|
FlashAttentionImpl,
|
|||
|
|
FlashAttentionMetadata,
|
|||
|
|
FlashAttentionMetadataBuilder)
|
|||
|
|
from vllm.distributed.parallel_state import get_tensor_model_parallel_rank
|
|||
|
|
from vllm.logger import init_logger
|
|||
|
|
from vllm.utils import async_tensor_h2d
|
|||
|
|
from vllm.vllm_flash_attn import (flash_attn_varlen_func,
|
|||
|
|
flash_attn_with_kvcache, sparse_attn_func)
|
|||
|
|
|
|||
|
|
if TYPE_CHECKING:
|
|||
|
|
from vllm.worker.model_runner import ModelInputForGPUBuilder
|
|||
|
|
|
|||
|
|
logger = init_logger(__name__)
|
|||
|
|
|
|||
|
|
|
|||
|
|
class DualChunkFlashAttentionBackend(FlashAttentionBackend):
    """Backend registration for dual-chunk flash attention.

    Maps the backend name "DUAL_CHUNK_FLASH_ATTN" to its implementation,
    metadata, and metadata-builder classes. Everything else is inherited
    from the regular flash-attention backend.
    """

    # This backend writes into its own output tensor (see
    # DualChunkFlashAttentionImpl.forward), so it cannot accept a
    # caller-provided output buffer.
    accept_output_buffer: bool = False

    @staticmethod
    def get_name() -> str:
        return "DUAL_CHUNK_FLASH_ATTN"

    @staticmethod
    def get_impl_cls() -> Type["DualChunkFlashAttentionImpl"]:
        return DualChunkFlashAttentionImpl

    @staticmethod
    def get_metadata_cls() -> Type["DualChunkFlashAttentionMetadata"]:
        return DualChunkFlashAttentionMetadata

    @staticmethod
    def get_builder_cls() -> Type["DualChunkFlashAttentionMetadataBuilder"]:
        return DualChunkFlashAttentionMetadataBuilder
|
@dataclass
class DualChunkFlashAttentionMetadata(FlashAttentionMetadata):
    """Attention metadata for dual-chunk flash attention.

    Extends FlashAttentionMetadata with the per-chunk bookkeeping needed to
    split decode-time attention into three phases: intra (inside the current
    chunk), succ (the immediately preceding chunk), and inter (all earlier
    chunks).
    """

    # Block size of the paged kv cache.
    block_size: int = 16

    # Original max position embeddings.
    original_max_position_embeddings: int = 0

    # Chunk size
    chunk_size: int = 8192

    # Local size
    local_size: int = 1024

    # (batch_size,). The orig sequence length per sequence.
    orig_seq_lens: Optional[List[int]] = None

    # orig_seq_lens stored as a tensor.
    orig_seq_lens_tensor: Optional[torch.Tensor] = None

    # Length scaling factor
    scaling_factor: Optional[torch.Tensor] = None

    # (batch_size,). Sequence lengths for intra attention.
    seq_lens_intra: Optional[torch.Tensor] = None

    # Max sequence length for intra attention.
    max_seq_len_intra: Optional[int] = None

    # (batch_size, num_blocks). Block table for intra attention.
    block_tables_intra: Optional[torch.Tensor] = None

    # (batch_size,). Sequence lengths for succ attention.
    seq_lens_succ: Optional[torch.Tensor] = None

    # Max sequence length for succ attention.
    max_seq_len_succ: Optional[int] = None

    # (batch_size, num_blocks). Block table for succ attention.
    block_tables_succ: Optional[torch.Tensor] = None

    # (batch_size,). Sequence lengths for inter attention.
    seq_lens_inter: Optional[torch.Tensor] = None

    # Max sequence length for inter attention.
    max_seq_len_inter: Optional[int] = None

    # Cached prefill/decode views of this metadata; computed lazily by the
    # properties below and reused on subsequent accesses.
    _cached_prefill_metadata: Optional[
        "DualChunkFlashAttentionMetadata"] = None
    _cached_decode_metadata: Optional["DualChunkFlashAttentionMetadata"] = None

    @property
    def prefill_metadata(self) -> Optional["DualChunkFlashAttentionMetadata"]:
        """Return metadata restricted to the prefill portion of the batch.

        Returns None when there are no prefill requests. The result is
        cached after the first computation.
        """
        if self.num_prefills == 0:
            return None

        if self._cached_prefill_metadata is not None:
            return self._cached_prefill_metadata

        # Start from the base-class prefill slice, then re-wrap it as a
        # DualChunkFlashAttentionMetadata so the extra fields exist.
        prefill_metadata = super().prefill_metadata
        if prefill_metadata is None:
            return None

        prefill_metadata = DualChunkFlashAttentionMetadata(
            **prefill_metadata.asdict_zerocopy())

        # Prefill requests occupy the first num_prefills entries.
        prefill_metadata.orig_seq_lens = (
            None if self.orig_seq_lens is None else
            self.orig_seq_lens[:self.num_prefills])
        prefill_metadata.orig_seq_lens_tensor = (
            None if self.orig_seq_lens_tensor is None else
            self.orig_seq_lens_tensor[:self.num_prefills])

        # Length-dependent logits scaling, clipped so short sequences are
        # never down-scaled below 1.
        if self.original_max_position_embeddings > 0:
            assert prefill_metadata.orig_seq_lens_tensor is not None
            prefill_metadata.scaling_factor = (
                0.1 * torch.log(prefill_metadata.orig_seq_lens_tensor /
                                self.original_max_position_embeddings) +
                1.0).clip(min=1)

        self._cached_prefill_metadata = prefill_metadata
        return prefill_metadata

    @property
    def decode_metadata(self) -> Optional["DualChunkFlashAttentionMetadata"]:
        """Return metadata restricted to the decode portion of the batch,
        with the intra/succ/inter sequence lengths and block tables filled
        in.

        Returns None when there are no decode tokens. The result is cached
        after the first computation.
        """
        if self.num_decode_tokens == 0:
            return None

        if self._cached_decode_metadata is not None:
            return self._cached_decode_metadata

        decode_metadata = super().decode_metadata
        if decode_metadata is None:
            return None

        decode_metadata = DualChunkFlashAttentionMetadata(
            **decode_metadata.asdict_zerocopy())

        # Decode requests follow the prefill requests in the batch.
        decode_metadata.orig_seq_lens_tensor = (
            None if self.orig_seq_lens_tensor is None else
            self.orig_seq_lens_tensor[self.num_prefills:])

        assert decode_metadata.orig_seq_lens_tensor is not None
        assert decode_metadata.block_tables is not None

        cache_seq_lens = decode_metadata.orig_seq_lens_tensor
        # Effective chunk stride: each chunk overlaps the next by local_size.
        chunk_len = self.chunk_size - self.local_size
        # Index of the chunk the last cached token falls into, per sequence.
        chunk_num_curr = (cache_seq_lens - 1) // chunk_len
        batch_size = decode_metadata.num_decode_tokens

        if self.original_max_position_embeddings > 0:
            decode_metadata.scaling_factor = (0.1 * torch.log(
                cache_seq_lens / self.original_max_position_embeddings) +
                                              1.0).clip(min=1)

        # Intra phase: tokens inside the current (possibly partial) chunk.
        seq_lens_intra = cache_seq_lens - chunk_num_curr * chunk_len
        max_seq_len_intra = seq_lens_intra.max().item()
        decode_metadata.seq_lens_intra = seq_lens_intra
        decode_metadata.max_seq_len_intra = max_seq_len_intra

        # Per-sequence block table slice covering only the current chunk,
        # left-aligned into a fresh (batch, num_blocks) tensor.
        block_tables_intra = torch.zeros(
            batch_size,
            (max_seq_len_intra - 1) // self.block_size + 1,
            dtype=decode_metadata.block_tables.dtype,
            device=decode_metadata.block_tables.device,
        )
        for i in range(batch_size):
            st = chunk_num_curr[i] * chunk_len // self.block_size
            ed = min(
                st + (max_seq_len_intra - 1) // self.block_size + 1,
                (cache_seq_lens[i] - 1) // self.block_size + 1,
            )
            block_tables_intra[i, :ed -
                               st] = decode_metadata.block_tables[i, st:ed]
        decode_metadata.block_tables_intra = block_tables_intra

        # Succ phase: the single chunk immediately preceding the current
        # one (length 0 when the sequence is still in its first chunk).
        seq_lens_succ = (chunk_num_curr -
                         (chunk_num_curr - 1).clip(min=0)) * chunk_len
        max_seq_len_succ = seq_lens_succ.max().item()
        decode_metadata.seq_lens_succ = seq_lens_succ
        decode_metadata.max_seq_len_succ = max_seq_len_succ
        if max_seq_len_succ:
            block_tables_succ = torch.zeros(
                batch_size,
                (max_seq_len_succ - 1) // self.block_size + 1,
                dtype=decode_metadata.block_tables.dtype,
                device=decode_metadata.block_tables.device,
            )
            for i in range(batch_size):
                start = ((chunk_num_curr[i] - 1).clip(min=0) * chunk_len //
                         self.block_size)
                end = min(
                    start + (max_seq_len_succ - 1) // self.block_size + 1,
                    (cache_seq_lens[i] - 1) // self.block_size + 1,
                )
                block_tables_succ[
                    i, :end - start] = decode_metadata.block_tables[i,
                                                                    start:end]
            decode_metadata.block_tables_succ = block_tables_succ

        # Inter phase: everything before the succ chunk. Reuses the full
        # block table, so only lengths are needed here.
        seq_lens_inter = (chunk_num_curr - 1).clip(min=0) * chunk_len
        max_seq_len_inter = seq_lens_inter.max().item()
        decode_metadata.seq_lens_inter = seq_lens_inter
        decode_metadata.max_seq_len_inter = max_seq_len_inter

        self._cached_decode_metadata = decode_metadata
        return decode_metadata
|
|
class DualChunkFlashAttentionMetadataBuilder(FlashAttentionMetadataBuilder):
    """Builds DualChunkFlashAttentionMetadata.

    Extends the regular flash-attention builder with the original
    (pre-truncation) sequence lengths and the dual-chunk configuration
    pulled from the model's HF config.
    """

    def prepare(self):
        super().prepare()
        # Original sequence length per sequence, filled in _add_seq_group.
        self.orig_seq_lens: List[int] = []

    def _add_seq_group(
            self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup",
            chunked_prefill_enabled: bool, prefix_cache_hit: bool):
        super()._add_seq_group(inter_data, chunked_prefill_enabled,
                               prefix_cache_hit)
        # The original length is the longer of prompt length and the
        # current sequence length.
        self.orig_seq_lens.extend(
            max(prompt_len, seq_len) for prompt_len, seq_len in zip(
                inter_data.prompt_lens, inter_data.seq_lens))

    def build(self, seq_lens: List[int], query_lens: List[int],
              cuda_graph_pad_size: int, batch_size: int):
        base_metadata = super().build(seq_lens, query_lens,
                                      cuda_graph_pad_size, batch_size)
        # Re-wrap the base metadata so the dual-chunk fields exist.
        metadata = DualChunkFlashAttentionMetadata(
            **base_metadata.asdict_zerocopy())

        metadata.orig_seq_lens = self.orig_seq_lens
        metadata.orig_seq_lens_tensor = async_tensor_h2d(
            self.orig_seq_lens, torch.int, self.runner.device,
            self.runner.pin_memory)

        metadata.block_size = self.runner.block_size
        dca_config = getattr(self.runner.model_config.hf_config,
                             "dual_chunk_attention_config", {})
        metadata.original_max_position_embeddings = dca_config.get(
            "original_max_position_embeddings", 0)
        metadata.chunk_size = dca_config.get("chunk_size", 8192)
        metadata.local_size = dca_config.get("local_size", 1024)

        return metadata
|
class DualChunkFlashAttentionImpl(FlashAttentionImpl):
|
|||
|
|
"""
|
|||
|
|
If the input tensors contain prompt tokens, the layout is as follows:
|
|||
|
|
|<--------------- num_prefill_tokens ----------------->|
|
|||
|
|
|<--prefill_0-->|<--prefill_1-->|...|<--prefill_N-1--->|
|
|||
|
|
Otherwise, the layout is as follows:
|
|||
|
|
|<----------------- num_decode_tokens ------------------>|
|
|||
|
|
|<--decode_0-->|..........|<--decode_M-1-->|<--padding-->|
|
|||
|
|
Generation tokens can contain padding when cuda-graph is used.
|
|||
|
|
Currently, prompt tokens don't contain any padding.
|
|||
|
|
The prompts might have different lengths, while the generation tokens
|
|||
|
|
always have length 1.
|
|||
|
|
If chunked prefill is enabled, prefill tokens and decode tokens can be
|
|||
|
|
batched together in a flattened 1D query.
|
|||
|
|
|<----- num_prefill_tokens ---->|<------- num_decode_tokens --------->|
|
|||
|
|
|<-prefill_0->|...|<-prefill_N-1->|<--decode_0-->|...|<--decode_M-1-->|
|
|||
|
|
Currently, cuda graph is disabled for chunked prefill, meaning there's no
|
|||
|
|
padding between prefill and decode tokens.
|
|||
|
|
"""
|
|||
|
|
|
|||
|
|
def __init__(
|
|||
|
|
self,
|
|||
|
|
num_heads: int,
|
|||
|
|
head_size: int,
|
|||
|
|
scale: float,
|
|||
|
|
num_kv_heads: int,
|
|||
|
|
alibi_slopes: Optional[List[float]],
|
|||
|
|
sliding_window: Optional[int],
|
|||
|
|
kv_cache_dtype: str,
|
|||
|
|
blocksparse_params: Optional[Dict[str, Any]] = None,
|
|||
|
|
logits_soft_cap: Optional[float] = None,
|
|||
|
|
attn_type: str = AttentionType.DECODER,
|
|||
|
|
kv_sharing_target_layer_name: Optional[str] = None,
|
|||
|
|
layer_idx: int = -1,
|
|||
|
|
dual_chunk_attention_config: Optional[Dict[str, Any]] = None,
|
|||
|
|
) -> None:
|
|||
|
|
if kv_sharing_target_layer_name is not None:
|
|||
|
|
raise NotImplementedError("KV sharing is not supported in V0.")
|
|||
|
|
self.num_heads = num_heads
|
|||
|
|
self.head_size = head_size
|
|||
|
|
self.scale = float(scale)
|
|||
|
|
self.num_kv_heads = num_kv_heads
|
|||
|
|
if alibi_slopes is not None:
|
|||
|
|
alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
|
|||
|
|
self.alibi_slopes = alibi_slopes
|
|||
|
|
self.sliding_window = ((sliding_window, sliding_window)
|
|||
|
|
if sliding_window is not None else (-1, -1))
|
|||
|
|
self.kv_cache_dtype = kv_cache_dtype
|
|||
|
|
|
|||
|
|
assert self.num_heads % self.num_kv_heads == 0
|
|||
|
|
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
|||
|
|
if sliding_window is not None:
|
|||
|
|
# NOTE(woosuk): flash-attn's sliding window does not work with
|
|||
|
|
# paged KV cache.
|
|||
|
|
raise ValueError(
|
|||
|
|
"Sliding window is not supported in FlashAttention.")
|
|||
|
|
|
|||
|
|
support_head_sizes = (
|
|||
|
|
DualChunkFlashAttentionBackend.get_supported_head_sizes())
|
|||
|
|
|
|||
|
|
if head_size not in support_head_sizes:
|
|||
|
|
raise ValueError(
|
|||
|
|
f"Head size {head_size} is not supported by FlashAttention. "
|
|||
|
|
f"Supported head sizes are: {support_head_sizes}.")
|
|||
|
|
|
|||
|
|
assert dual_chunk_attention_config is not None
|
|||
|
|
self.chunk_size = dual_chunk_attention_config.get("chunk_size", 8192)
|
|||
|
|
self.local_size = dual_chunk_attention_config.get("local_size", 1024)
|
|||
|
|
self.original_max_position_embeddings = dual_chunk_attention_config.get(
|
|||
|
|
"original_max_position_embeddings", 0)
|
|||
|
|
self.sparse_attention_config = dual_chunk_attention_config.get(
|
|||
|
|
"sparse_attention_config", None)
|
|||
|
|
if not self.sparse_attention_config:
|
|||
|
|
logger.warning_once("Sparse attention will not be enabled as "
|
|||
|
|
"sparse attention config is not provided.")
|
|||
|
|
self.sparse_attention_enabled = dual_chunk_attention_config.get(
|
|||
|
|
"sparse_attention_enabled", self.sparse_attention_config
|
|||
|
|
is not None)
|
|||
|
|
self.sparse_attention_threshold = dual_chunk_attention_config.get(
|
|||
|
|
"sparse_attention_threshold", 32768)
|
|||
|
|
self.sparse_attention_last_q = dual_chunk_attention_config.get(
|
|||
|
|
"sparse_attention_last_q", 64)
|
|||
|
|
self.layer_idx = layer_idx
|
|||
|
|
self.dual_chunk_attention_config = dual_chunk_attention_config
|
|||
|
|
|
|||
|
|
if self.sparse_attention_config:
|
|||
|
|
self.sparse_attention_config = {
|
|||
|
|
int(i): j
|
|||
|
|
for i, j in self.sparse_attention_config[
|
|||
|
|
self.layer_idx].items()
|
|||
|
|
}
|
|||
|
|
start_head = self.num_heads * get_tensor_model_parallel_rank()
|
|||
|
|
end_head = start_head + self.num_heads
|
|||
|
|
self.sparse_attention_config = [
|
|||
|
|
self.sparse_attention_config[i]
|
|||
|
|
for i in range(start_head, end_head)
|
|||
|
|
]
|
|||
|
|
|
|||
|
|
if self.sparse_attention_enabled:
|
|||
|
|
self.arange = torch.arange(self.sparse_attention_last_q,
|
|||
|
|
device="cuda")
|
|||
|
|
self.last_q_mask = (self.arange[None, None, :, None]
|
|||
|
|
>= self.arange[None, None, None, :])
|
|||
|
|
|
|||
|
|
    def forward(  # type: ignore
        self,
        layer: AttentionLayer,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: DualChunkFlashAttentionMetadata,
    ) -> torch.Tensor:
        """Forward pass with DualChunkFlashAttention.

        Args:
            query: shape = [num_tokens, num_heads * head_size]
            query_succ: shape = [num_tokens, num_heads * head_size]
            query_inter: shape = [num_tokens, num_heads * head_size]
            key: shape = [num_tokens, num_kv_heads * head_size]
            value: shape = [num_tokens, num_kv_heads * head_size]
            kv_cache = [2, num_blocks, block_size, num_kv_heads * head_size]
            attn_metadata: Metadata for attention.
        Returns:
            shape = [num_tokens, num_heads * head_size]
        """
        # The incoming `query` packs five projections along the hidden
        # dimension: the main query, the succ/inter variants, and their
        # "critical" counterparts used for sparse-pattern selection.
        (
            query,
            query_succ,
            query_inter,
            query_succ_critical,
            query_inter_critical,
        ) = torch.split(query, query.shape[-1] // 5, dim=-1)

        assert (
            query_succ is not None and query_inter is not None
        ), "query_succ and query_inter are required in Dual Chunk Attention."

        num_tokens, hidden_size = query.shape

        # Reshape the query, key, and value tensors to
        # (num_tokens, num_heads, head_size).
        query = query.view(-1, self.num_heads, self.head_size)
        query_succ = query_succ.view(-1, self.num_heads, self.head_size)
        query_inter = query_inter.view(-1, self.num_heads, self.head_size)
        query_succ_critical = query_succ_critical.view(-1, self.num_heads,
                                                       self.head_size)
        query_inter_critical = query_inter_critical.view(
            -1, self.num_heads, self.head_size)
        key = key.view(-1, self.num_kv_heads, self.head_size)
        value = value.view(-1, self.num_kv_heads, self.head_size)

        # Length-extrapolation scaling: multiply new keys in place by the
        # per-sequence scaling factor BEFORE they are written to the cache.
        if self.original_max_position_embeddings > 0:
            if prefill_meta := attn_metadata.prefill_metadata:
                assert prefill_meta.scaling_factor is not None
                assert prefill_meta.query_start_loc is not None
                assert prefill_meta.orig_seq_lens is not None
                current_start = 0
                query_start_loc_cpu = prefill_meta.query_start_loc.cpu()
                # Walk the flattened prefill tokens sequence by sequence,
                # scaling each sequence's keys by its own factor.
                for i in range(len(prefill_meta.orig_seq_lens)):
                    current_end = (current_start +
                                   (query_start_loc_cpu[i + 1] -
                                    query_start_loc_cpu[i]).item())
                    key[current_start:current_end].mul_(
                        prefill_meta.scaling_factor[i])
                    current_start = current_end
                assert current_end <= attn_metadata.num_prefill_tokens
            if decode_meta := attn_metadata.decode_metadata:
                assert decode_meta.scaling_factor is not None
                scaling_factor = decode_meta.scaling_factor
                # Decode tokens: one new key per sequence; broadcast the
                # (batch,) factor over heads and head_size.
                key[attn_metadata.num_prefill_tokens:].mul_(
                    scaling_factor.unsqueeze(-1).unsqueeze(-1))

        if kv_cache is not None and kv_cache.numel() > 0:
            key_cache = kv_cache[0]
            value_cache = kv_cache[1]

            # Reshape the input keys and values and store them in the cache.
            # If kv_cache is not provided, the new key and value tensors are
            # not cached. This happens during the initial memory profiling run.
            ops.reshape_and_cache_flash(
                key,
                value,
                key_cache,
                value_cache,
                attn_metadata.slot_mapping.flatten(),
                self.kv_cache_dtype,
                layer._k_scale,
                layer._v_scale,
            )

        num_prefill_tokens = attn_metadata.num_prefill_tokens
        num_decode_tokens = attn_metadata.num_decode_tokens
        assert key.shape[0] == num_prefill_tokens + num_decode_tokens
        assert value.shape[0] == num_prefill_tokens + num_decode_tokens
        output = torch.empty_like(query)

        # Query for decode. KV is not needed because it is already cached.
        decode_query = query[num_prefill_tokens:]
        decode_query_succ = query_succ[num_prefill_tokens:]
        decode_query_inter = query_inter[num_prefill_tokens:]

        # QKV for prefill.
        query = query[:num_prefill_tokens]
        query_succ = query_succ[:num_prefill_tokens]
        query_inter = query_inter[:num_prefill_tokens]
        query_succ_critical = query_succ_critical[:num_prefill_tokens]
        query_inter_critical = query_inter_critical[:num_prefill_tokens]
        key = key[:num_prefill_tokens]
        value = value[:num_prefill_tokens]
        assert query.shape[0] == num_prefill_tokens
        assert decode_query.shape[0] == num_decode_tokens

        if prefill_meta := attn_metadata.prefill_metadata:
            # Prompt run.
            if (kv_cache is None or prefill_meta.block_tables is None
                    or prefill_meta.block_tables.numel() == 0):
                # normal attention, called during the profiling run.
                out = flash_attn_varlen_func(
                    q=query,
                    k=key,
                    v=value,
                    cu_seqlens_q=prefill_meta.seq_start_loc,
                    cu_seqlens_k=prefill_meta.seq_start_loc,
                    max_seqlen_q=prefill_meta.max_prefill_seq_len,
                    max_seqlen_k=prefill_meta.max_prefill_seq_len,
                    softmax_scale=self.scale,
                    causal=True,
                    window_size=self.sliding_window,
                    alibi_slopes=self.alibi_slopes,
                )
                assert output[:num_prefill_tokens].shape == out.shape
                output[:num_prefill_tokens] = out
            else:
                # prefix-enabled attention: keys/values come from the paged
                # cache (key_cache/value_cache set above), not from `key`.
                assert prefill_meta.seq_lens is not None
                assert prefill_meta.orig_seq_lens is not None
                output[:num_prefill_tokens] = (
                    self._dual_chunk_flash_attn_prefill(
                        q=query,
                        q_succ=query_succ,
                        q_inter=query_inter,
                        q_succ_critical=query_succ_critical,
                        q_inter_critical=query_inter_critical,
                        k=key_cache,
                        v=value_cache,
                        cu_seqlens_q=prefill_meta.query_start_loc,
                        cu_seqlens_k=prefill_meta.seq_start_loc,
                        orig_seq_lens=prefill_meta.orig_seq_lens,
                        scaling_factor=prefill_meta.scaling_factor,
                        softmax_scale=self.scale,
                        causal=True,
                        window_size=(-1, -1),
                        alibi_slopes=self.alibi_slopes,
                        block_table=prefill_meta.block_tables,
                        chunk_size=self.chunk_size,
                        local_size=self.local_size,
                    ))

        if decode_meta := attn_metadata.decode_metadata:
            # Decoding run. Queries are unsqueezed to (batch, 1, heads, dim)
            # for the with-kvcache kernels, then squeezed back.
            output[num_prefill_tokens:] = (
                self._dual_chunk_flash_attn_decoding(
                    decode_query.unsqueeze(1),
                    decode_query_succ.unsqueeze(1),
                    decode_query_inter.unsqueeze(1),
                    key_cache,
                    value_cache,
                    block_table=decode_meta.block_tables,
                    cache_seqlens=decode_meta.seq_lens_tensor,
                    softmax_scale=self.scale,
                    causal=True,
                    alibi_slopes=self.alibi_slopes,
                    chunk_size=self.chunk_size,
                    local_size=self.local_size,
                    original_max_position_embeddings=self.
                    original_max_position_embeddings,
                    decode_meta=decode_meta,
                ).squeeze(1))
        # Reshape the output tensor.
        return output.view(num_tokens, hidden_size)
|
    def _dual_chunk_flash_attn_prefill(
        self,
        q,
        q_succ,
        q_inter,
        q_succ_critical,
        q_inter_critical,
        k,
        v,
        cu_seqlens_q,
        cu_seqlens_k,
        orig_seq_lens: List[int],
        scaling_factor: torch.Tensor,
        softmax_scale: float,
        causal: Optional[bool] = True,
        window_size: Tuple[int, int] = (-1, -1),
        alibi_slopes: Optional[torch.Tensor] = None,
        block_table: Optional[torch.Tensor] = None,
        chunk_size: int = 8192,
        local_size: int = 1024,
    ):
        """Dual-chunk prefill attention over a varlen batch.

        Processes each sequence in the batch independently, dispatching to
        the per-chunk kernel either for all heads at once (sparse path) or
        head by head (dense path). Only causal, non-alibi, full-window
        attention is supported.
        """
        if alibi_slopes is not None:
            raise ValueError(
                "Dual Chunk Attention does not support alibi_slopes")
        if not causal:
            raise ValueError(
                "Dual Chunk Attention does not support causal=False")
        if window_size != (-1, -1):
            raise ValueError(
                "Dual Chunk Attention does not support window_size")

        cu_seqlens_q_cpu = cu_seqlens_q.cpu().tolist()
        cu_seqlens_k_cpu = cu_seqlens_k.cpu().tolist()
        all_outputs = []

        # Walk the cumulative-length arrays sequence by sequence.
        for i in range(0, len(cu_seqlens_q_cpu) - 1):
            qs = cu_seqlens_q_cpu[i]
            # [i:i + 2][-1] is equivalent to [i + 1] here (i + 1 is always
            # in range inside this loop).
            qe = cu_seqlens_q_cpu[i:i + 2][-1]
            ks = cu_seqlens_k_cpu[i]
            ke = cu_seqlens_k_cpu[i:i + 2][-1]

            current_q = q[qs:qe]
            current_q_succ = q_succ[qs:qe]
            current_q_inter = q_inter[qs:qe]
            current_q_succ_critical = q_succ_critical[qs:qe]
            current_q_inter_critical = q_inter_critical[qs:qe]

            if block_table is None:
                # Contiguous K/V: slice this sequence's span directly.
                current_k = k[ks:ke]
                current_v = v[ks:ke]
                current_block_table = None
                current_orig_seq_len = orig_seq_lens[i]
            else:
                # Paged K/V: pass the whole cache plus this sequence's
                # block table; the kernel gathers via the table.
                current_block_table = block_table[i]
                current_orig_seq_len = orig_seq_lens[i]
                current_k = k
                current_v = v
            # Sparse attention only kicks in beyond the length threshold.
            sparse_attn_enabled = (self.sparse_attention_enabled
                                   and current_orig_seq_len
                                   > self.sparse_attention_threshold)

            if current_q.shape[0] == 0:
                continue

            if current_k.shape[0] == 0:
                # No keys at all: attention output is all zeros.
                all_outputs.append(
                    torch.zeros(
                        (current_q.shape[0], current_q.shape[1], v.shape[2]),
                        device=q.device,
                        dtype=q.dtype,
                    ))
                continue

            current_output = torch.empty_like(current_q)
            # GQA group size: query heads per KV head.
            group_size = int(current_q.size(-2) / current_k.size(-2))

            if sparse_attn_enabled:
                # Collect the per-head vertical/slash budgets from the
                # sparse config, then run all heads in one call.
                num_device_q_heads = current_q.size(-2)
                heads_vertical_size = torch.empty(size=(num_device_q_heads, ),
                                                  dtype=torch.int32)
                heads_slash_size = torch.empty(size=(num_device_q_heads, ),
                                               dtype=torch.int32)
                for head_id in range(current_q.size(-2)):
                    (
                        ty,
                        vertical_size,
                        slash_size,
                        _,
                    ) = self.sparse_attention_config[head_id]
                    assert ty == "vertical_and_slash", "only support slash mode"

                    # NOTE(review): 30 appears to be a sentinel budget that
                    # is bumped to 130 — confirm against the sparse config
                    # generator.
                    if vertical_size == 30:
                        vertical_size += 100
                    heads_vertical_size[head_id] = vertical_size
                    heads_slash_size[head_id] = slash_size

                # NOTE(review): scaling_factor may be None when
                # original_max_position_embeddings == 0 (see
                # prefill_metadata); this indexing assumes it is set —
                # confirm upstream guarantees.
                current_output = self._dual_chunk_flash_attn_prefill_func(
                    current_q,  # allheads
                    current_q_succ,
                    current_q_inter,
                    current_q_succ_critical,
                    current_q_inter_critical,
                    current_k,
                    current_v,
                    current_block_table,
                    softmax_scale,
                    chunk_size,
                    local_size,
                    scaling_factor[i].item(),
                    ke - ks,
                    sparse_attn_enabled=sparse_attn_enabled,
                    heads_vertical_size=heads_vertical_size,
                    heads_slash_size=heads_slash_size,
                    group_size=group_size)
            else:
                # Dense path: run the chunked kernel one head at a time,
                # selecting the matching KV head for GQA.
                for head_id in range(current_q.size(-2)):
                    # (seq_len, num_heads, head_size)
                    current_q_head = current_q[:, head_id, :].unsqueeze(1)
                    current_q_succ_head = \
                        current_q_succ[:, head_id, :].unsqueeze(1)
                    current_q_inter_head = \
                        current_q_inter[:, head_id, :].unsqueeze(1)
                    current_q_succ_head_critical = \
                        current_q_succ_critical[:, head_id, :].unsqueeze(1)
                    current_q_inter_head_critical = \
                        current_q_inter_critical[:, head_id, :].unsqueeze(1)
                    if block_table is not None:
                        current_k_head = current_k[..., head_id //
                                                   group_size, :].unsqueeze(2)
                        current_v_head = current_v[..., head_id //
                                                   group_size, :].unsqueeze(2)

                    else:
                        current_k_head = current_k[:, head_id, :].unsqueeze(1)
                        current_v_head = current_v[:, head_id, :].unsqueeze(1)

                    current_out = self._dual_chunk_flash_attn_prefill_func(
                        current_q_head,
                        current_q_succ_head,
                        current_q_inter_head,
                        current_q_succ_head_critical,
                        current_q_inter_head_critical,
                        current_k_head,
                        current_v_head,
                        current_block_table,
                        softmax_scale,
                        chunk_size,
                        local_size,
                        scaling_factor[i].item(),
                        ke - ks,
                        sparse_attn_enabled=sparse_attn_enabled,
                    )
                    current_output[:, head_id:head_id + 1, :] = current_out
            all_outputs.append(current_output)
        return torch.cat(all_outputs, dim=0)
|
|
def _dual_chunk_flash_attn_prefill_func(
|
|||
|
|
self,
|
|||
|
|
q,
|
|||
|
|
q_succ,
|
|||
|
|
q_inter,
|
|||
|
|
q_succ_critical,
|
|||
|
|
q_inter_critical,
|
|||
|
|
k,
|
|||
|
|
v,
|
|||
|
|
block_table,
|
|||
|
|
softmax_scale: float,
|
|||
|
|
chunk_size: int,
|
|||
|
|
local_size: int,
|
|||
|
|
scaling_factor: float,
|
|||
|
|
k_length: int,
|
|||
|
|
sparse_attn_enabled: Optional[bool] = True,
|
|||
|
|
heads_vertical_size=None,
|
|||
|
|
heads_slash_size=None,
|
|||
|
|
group_size=None,
|
|||
|
|
):
|
|||
|
|
flash_results = []
|
|||
|
|
chunk_len = chunk_size - local_size
|
|||
|
|
|
|||
|
|
if block_table is not None:
|
|||
|
|
block_size = v.shape[1]
|
|||
|
|
if chunk_len % block_size != 0:
|
|||
|
|
raise ValueError("chunk_len must be divisible by block_size.")
|
|||
|
|
else:
|
|||
|
|
block_size = 1
|
|||
|
|
|
|||
|
|
if self.original_max_position_embeddings > 0:
|
|||
|
|
softmax_scale = softmax_scale * scaling_factor
|
|||
|
|
|
|||
|
|
begin = k_length - q.shape[0]
|
|||
|
|
while begin < k_length:
|
|||
|
|
flash_per_chunk = []
|
|||
|
|
|
|||
|
|
prev_chunk_end_pos = (begin // chunk_len) * chunk_len
|
|||
|
|
next_chunk_end_pos = prev_chunk_end_pos + chunk_len
|
|||
|
|
end = min(next_chunk_end_pos, k_length)
|
|||
|
|
qbegin = begin - (k_length - q.shape[0])
|
|||
|
|
qend = end - (k_length - q.shape[0])
|
|||
|
|
|
|||
|
|
qk_chunks = []
|
|||
|
|
q_states_intra = q[qbegin:qend]
|
|||
|
|
# choose critical token
|
|||
|
|
if block_table is not None:
|
|||
|
|
block_tables_intra = _get_block(block_table, block_size,
|
|||
|
|
prev_chunk_end_pos, end)
|
|||
|
|
k_states_intra = k[block_tables_intra].view(
|
|||
|
|
-1, *k.shape[-2:])[:(end - prev_chunk_end_pos)]
|
|||
|
|
v_states_intra = v[block_tables_intra].view(
|
|||
|
|
-1, *v.shape[-2:])[:(end - prev_chunk_end_pos)]
|
|||
|
|
else:
|
|||
|
|
block_tables_intra = None
|
|||
|
|
k_states_intra = k[prev_chunk_end_pos:end]
|
|||
|
|
v_states_intra = v[prev_chunk_end_pos:end]
|
|||
|
|
|
|||
|
|
if sparse_attn_enabled:
|
|||
|
|
last_q_size = min(qend - qbegin, self.sparse_attention_last_q)
|
|||
|
|
_, num_device_k_heads, head_dim = k_states_intra.shape
|
|||
|
|
k_states_intra = (k_states_intra.unsqueeze(2).repeat(
|
|||
|
|
1, 1, group_size,
|
|||
|
|
1).reshape(-1, num_device_k_heads * group_size, head_dim))
|
|||
|
|
v_states_intra = (v_states_intra.unsqueeze(2).repeat(
|
|||
|
|
1, 1, group_size,
|
|||
|
|
1).reshape(-1, num_device_k_heads * group_size, head_dim))
|
|||
|
|
qk_chunks.append(
|
|||
|
|
(q_states_intra.transpose(0, 1)[:, -last_q_size:] *
|
|||
|
|
softmax_scale) @ k_states_intra.permute(1, 2, 0))
|
|||
|
|
|
|||
|
|
if prev_chunk_end_pos - chunk_len >= 0:
|
|||
|
|
q_states_succ = q_succ[qbegin:qend]
|
|||
|
|
q_states_succ_critical = q_succ_critical[qbegin:qend]
|
|||
|
|
if block_table is not None:
|
|||
|
|
block_tables_succ = _get_block(
|
|||
|
|
block_table, block_size,
|
|||
|
|
prev_chunk_end_pos - chunk_len, prev_chunk_end_pos)
|
|||
|
|
k_states_succ = k[block_tables_succ].view(
|
|||
|
|
-1, *k.shape[-2:])[:chunk_len]
|
|||
|
|
v_states_succ = v[block_tables_succ].view(
|
|||
|
|
-1, *v.shape[-2:])[:chunk_len]
|
|||
|
|
else:
|
|||
|
|
k_states_succ = k[prev_chunk_end_pos -
|
|||
|
|
chunk_len:prev_chunk_end_pos]
|
|||
|
|
v_states_succ = v[prev_chunk_end_pos -
|
|||
|
|
chunk_len:prev_chunk_end_pos]
|
|||
|
|
|
|||
|
|
if sparse_attn_enabled:
|
|||
|
|
k_states_succ = (k_states_succ.unsqueeze(2).repeat(
|
|||
|
|
1, 1, group_size,
|
|||
|
|
1).reshape(-1, num_device_k_heads * group_size,
|
|||
|
|
head_dim))
|
|||
|
|
v_states_succ = (v_states_succ.unsqueeze(2).repeat(
|
|||
|
|
1, 1, group_size,
|
|||
|
|
1).reshape(-1, num_device_k_heads * group_size,
|
|||
|
|
head_dim))
|
|||
|
|
qk_chunks.append((q_states_succ_critical.transpose(
|
|||
|
|
0, 1)[:, -last_q_size:] * softmax_scale)
|
|||
|
|
@ k_states_succ.permute(1, 2, 0))
|
|||
|
|
|
|||
|
|
if prev_chunk_end_pos - chunk_len * 2 >= 0:
|
|||
|
|
q_states_inter = q_inter[qbegin:qend]
|
|||
|
|
q_states_inter_critical = q_inter_critical[qbegin:qend]
|
|||
|
|
if block_table is not None:
|
|||
|
|
block_tables_inter = _get_block(
|
|||
|
|
block_table, block_size, 0,
|
|||
|
|
prev_chunk_end_pos - chunk_len)
|
|||
|
|
k_states_inter = k[block_tables_inter].view(
|
|||
|
|
-1, *k.shape[-2:])[:(prev_chunk_end_pos - chunk_len)]
|
|||
|
|
v_states_inter = v[block_tables_inter].view(
|
|||
|
|
-1, *v.shape[-2:])[:(prev_chunk_end_pos - chunk_len)]
|
|||
|
|
else:
|
|||
|
|
k_states_inter = k[:prev_chunk_end_pos - chunk_len]
|
|||
|
|
v_states_inter = v[:prev_chunk_end_pos - chunk_len]
|
|||
|
|
|
|||
|
|
if sparse_attn_enabled:
|
|||
|
|
k_states_inter = (k_states_inter.unsqueeze(2).repeat(
|
|||
|
|
1, 1, group_size,
|
|||
|
|
1).reshape(-1, num_device_k_heads * group_size,
|
|||
|
|
head_dim))
|
|||
|
|
v_states_inter = (v_states_inter.unsqueeze(2).repeat(
|
|||
|
|
1, 1, group_size,
|
|||
|
|
1).reshape(-1, num_device_k_heads * group_size,
|
|||
|
|
head_dim))
|
|||
|
|
qk_chunks.append((q_states_inter_critical.transpose(
|
|||
|
|
0, 1)[:, -last_q_size:] * softmax_scale)
|
|||
|
|
@ k_states_inter.permute(1, 2, 0))
|
|||
|
|
|
|||
|
|
if sparse_attn_enabled:
|
|||
|
|
reversed_qk = qk_chunks[::-1]
|
|||
|
|
qk = torch.cat(reversed_qk, dim=-1)
|
|||
|
|
|
|||
|
|
qk[:, :, -last_q_size:] = torch.where(
|
|||
|
|
self.last_q_mask[..., -last_q_size:,
|
|||
|
|
-last_q_size:].to(qk.device),
|
|||
|
|
qk[:, :, -last_q_size:], -torch.inf)
|
|||
|
|
qk = F.softmax(qk, dim=-1, dtype=torch.float32)
|
|||
|
|
|
|||
|
|
vertical = qk.sum(-2, keepdim=True)
|
|||
|
|
vertical[..., :30] = torch.inf
|
|||
|
|
|
|||
|
|
# Avoid sorting by using the min/max ints to fill the indexer
|
|||
|
|
# buffers.
|
|||
|
|
int32_max = torch.iinfo(torch.int32).max
|
|||
|
|
int32_min = torch.iinfo(torch.int32).min
|
|||
|
|
n_heads = qk.size()[0]
|
|||
|
|
max_slash_topk = torch.max(heads_slash_size).item()
|
|||
|
|
max_vertical_topk = torch.max(heads_vertical_size).item()
|
|||
|
|
# store each head's slash topk, vertical topk
|
|||
|
|
vertical = vertical.reshape((n_heads, -1))
|
|||
|
|
# prevent out of range when prompt size < max_vertical_topk
|
|||
|
|
max_vertical_topk = min(vertical.shape[-1], max_vertical_topk)
|
|||
|
|
vertical_topk_buffer = torch.topk(vertical, max_vertical_topk,
|
|||
|
|
-1).indices
|
|||
|
|
slash_topk_buffer = torch.empty(size=(n_heads, max_slash_topk),
|
|||
|
|
dtype=torch.int64,
|
|||
|
|
device=qk.device)
|
|||
|
|
for head_i in range(n_heads):
|
|||
|
|
# (nqheads=1, lastq, k_len)
|
|||
|
|
head_score = qk[head_i:head_i + 1, :, :]
|
|||
|
|
slash_scores = _sum_all_diagonal_matrix(head_score)
|
|||
|
|
if head_score.size(1) != 1:
|
|||
|
|
# drop right up corner
|
|||
|
|
slash_scores = slash_scores[..., :-last_q_size + 1]
|
|||
|
|
slash_scores[..., -100:] = torch.inf
|
|||
|
|
|
|||
|
|
head_slash_size = heads_slash_size[head_i]
|
|||
|
|
head_slash_size = min(head_slash_size, vertical.size(-1))
|
|||
|
|
slash_topk = torch.topk(slash_scores, head_slash_size,
|
|||
|
|
-1).indices
|
|||
|
|
#(nheads, max_topk)
|
|||
|
|
slash_topk_buffer[head_i, :head_slash_size] = slash_topk
|
|||
|
|
|
|||
|
|
# reset heads topk
|
|||
|
|
heads_slash_size[head_i] = head_slash_size
|
|||
|
|
heads_vertical_size[head_i] = min(
|
|||
|
|
heads_vertical_size[head_i], max_vertical_topk)
|
|||
|
|
|
|||
|
|
# store
|
|||
|
|
vertical_buffer = torch.full((n_heads, max_vertical_topk),
|
|||
|
|
int32_max,
|
|||
|
|
dtype=torch.int64,
|
|||
|
|
device=q.device)
|
|||
|
|
slash_buffer = torch.full((n_heads, max_slash_topk),
|
|||
|
|
int32_min,
|
|||
|
|
dtype=torch.int64,
|
|||
|
|
device=q.device)
|
|||
|
|
succ_vertical_buffer = torch.full((n_heads, max_vertical_topk),
|
|||
|
|
int32_max,
|
|||
|
|
dtype=torch.int64,
|
|||
|
|
device=q.device)
|
|||
|
|
succ_slash_buffer = torch.full((n_heads, max_slash_topk),
|
|||
|
|
int32_min,
|
|||
|
|
dtype=torch.int64,
|
|||
|
|
device=q.device)
|
|||
|
|
inter_vertical_buffer = torch.full(
|
|||
|
|
(n_heads, max_vertical_topk),
|
|||
|
|
int32_max,
|
|||
|
|
dtype=torch.int64,
|
|||
|
|
device=q.device)
|
|||
|
|
inter_slash_buffer = torch.full((n_heads, max_slash_topk),
|
|||
|
|
int32_min,
|
|||
|
|
dtype=torch.int64,
|
|||
|
|
device=q.device)
|
|||
|
|
|
|||
|
|
vertical_size_buffer = torch.empty(size=(n_heads, ),
|
|||
|
|
dtype=torch.int32,
|
|||
|
|
device=q.device)
|
|||
|
|
slash_sizes_buffer = torch.empty(size=(n_heads, ),
|
|||
|
|
dtype=torch.int32,
|
|||
|
|
device=q.device)
|
|||
|
|
succ_vertical_size_buffer = torch.empty(size=(n_heads, ),
|
|||
|
|
dtype=torch.int32,
|
|||
|
|
device=q.device)
|
|||
|
|
succ_slash_sizes_buffer = torch.empty(size=(n_heads, ),
|
|||
|
|
dtype=torch.int32,
|
|||
|
|
device=q.device)
|
|||
|
|
inter_vertical_size_buffer = torch.empty(size=(n_heads, ),
|
|||
|
|
dtype=torch.int32,
|
|||
|
|
device=q.device)
|
|||
|
|
inter_slash_sizes_buffer = torch.empty(size=(n_heads, ),
|
|||
|
|
dtype=torch.int32,
|
|||
|
|
device=q.device)
|
|||
|
|
|
|||
|
|
for head_i in range(n_heads):
|
|||
|
|
vertical_topk = vertical_topk_buffer[
|
|||
|
|
head_i, :heads_vertical_size[head_i]]
|
|||
|
|
# intra
|
|||
|
|
intra_vertical_indices = vertical_topk[
|
|||
|
|
vertical_topk >=
|
|||
|
|
prev_chunk_end_pos] - prev_chunk_end_pos
|
|||
|
|
if intra_vertical_indices.nelement() == 0:
|
|||
|
|
intra_vertical_indices = torch.cat([
|
|||
|
|
intra_vertical_indices,
|
|||
|
|
torch.arange(0,
|
|||
|
|
k_states_intra.size(0),
|
|||
|
|
max(1,
|
|||
|
|
k_states_intra.size(0) / 5),
|
|||
|
|
dtype=torch.int32,
|
|||
|
|
device=intra_vertical_indices.device)
|
|||
|
|
])
|
|||
|
|
slash_topk = slash_topk_buffer[
|
|||
|
|
head_i, :heads_slash_size[head_i]]
|
|||
|
|
intra_slash_indices = (
|
|||
|
|
(qk.size(-1) - 1) -
|
|||
|
|
slash_topk[slash_topk >= prev_chunk_end_pos])
|
|||
|
|
# fill buffer
|
|||
|
|
v_count = intra_vertical_indices.nelement()
|
|||
|
|
s_count = intra_slash_indices.nelement()
|
|||
|
|
vertical_size_buffer[head_i] = v_count
|
|||
|
|
slash_sizes_buffer[head_i] = s_count
|
|||
|
|
vertical_buffer[head_i, :v_count].copy_(
|
|||
|
|
intra_vertical_indices)
|
|||
|
|
slash_buffer[head_i, :s_count].copy_(intra_slash_indices)
|
|||
|
|
# succ
|
|||
|
|
if prev_chunk_end_pos - chunk_len >= 0:
|
|||
|
|
succ_vertical_indices = vertical_topk[
|
|||
|
|
(vertical_topk < prev_chunk_end_pos)
|
|||
|
|
& (vertical_topk >= prev_chunk_end_pos -
|
|||
|
|
chunk_len)] - (prev_chunk_end_pos - chunk_len)
|
|||
|
|
# TODO: support no vertical
|
|||
|
|
if succ_vertical_indices.nelement() == 0:
|
|||
|
|
succ_vertical_indices = torch.cat([
|
|||
|
|
succ_vertical_indices,
|
|||
|
|
torch.arange(
|
|||
|
|
0,
|
|||
|
|
k_states_succ.size(0),
|
|||
|
|
max(1,
|
|||
|
|
k_states_succ.size(0) / 5),
|
|||
|
|
dtype=torch.int32,
|
|||
|
|
device=intra_vertical_indices.device)
|
|||
|
|
])
|
|||
|
|
succ_slash_indices = (
|
|||
|
|
(prev_chunk_end_pos + (qend - qbegin) - 1) -
|
|||
|
|
slash_topk[((slash_topk >=
|
|||
|
|
(prev_chunk_end_pos - chunk_len)) &
|
|||
|
|
(slash_topk < (prev_chunk_end_pos +
|
|||
|
|
(qend - qbegin))))])
|
|||
|
|
if succ_slash_indices.nelement() == 0:
|
|||
|
|
succ_slash_indices = torch.cat([
|
|||
|
|
succ_slash_indices,
|
|||
|
|
torch.arange(
|
|||
|
|
0,
|
|||
|
|
k_states_succ.size(0),
|
|||
|
|
max(1,
|
|||
|
|
k_states_succ.size(0) / 5),
|
|||
|
|
dtype=torch.int32,
|
|||
|
|
device=intra_vertical_indices.device)
|
|||
|
|
])
|
|||
|
|
# fill buffer
|
|||
|
|
v_count = succ_vertical_indices.nelement()
|
|||
|
|
s_count = succ_slash_indices.nelement()
|
|||
|
|
succ_vertical_size_buffer[head_i] = v_count
|
|||
|
|
succ_slash_sizes_buffer[head_i] = s_count
|
|||
|
|
succ_vertical_buffer[head_i, :v_count].copy_(
|
|||
|
|
succ_vertical_indices)
|
|||
|
|
succ_slash_buffer[head_i, :s_count].copy_(
|
|||
|
|
succ_slash_indices)
|
|||
|
|
|
|||
|
|
if prev_chunk_end_pos - 2 * chunk_len >= 0:
|
|||
|
|
inter_vertical_indices = vertical_topk[
|
|||
|
|
vertical_topk < prev_chunk_end_pos - chunk_len]
|
|||
|
|
|
|||
|
|
if inter_vertical_indices.nelement() == 0:
|
|||
|
|
inter_vertical_indices = torch.cat([
|
|||
|
|
inter_vertical_indices,
|
|||
|
|
torch.arange(
|
|||
|
|
0,
|
|||
|
|
k_states_inter.size(0),
|
|||
|
|
max(1,
|
|||
|
|
k_states_inter.size(0) / 5),
|
|||
|
|
dtype=torch.int32,
|
|||
|
|
device=intra_vertical_indices.device)
|
|||
|
|
])
|
|||
|
|
inter_slash_indices = (
|
|||
|
|
(prev_chunk_end_pos - chunk_len +
|
|||
|
|
(qend - qbegin) - 1) -
|
|||
|
|
slash_topk[slash_topk < (prev_chunk_end_pos -
|
|||
|
|
chunk_len +
|
|||
|
|
(qend - qbegin))])
|
|||
|
|
if inter_slash_indices.nelement() == 0:
|
|||
|
|
inter_slash_indices = torch.cat([
|
|||
|
|
inter_slash_indices,
|
|||
|
|
torch.arange(
|
|||
|
|
0,
|
|||
|
|
k_states_inter.size(0),
|
|||
|
|
max(1,
|
|||
|
|
k_states_inter.size(0) / 5),
|
|||
|
|
dtype=torch.int32,
|
|||
|
|
device=intra_vertical_indices.device)
|
|||
|
|
])
|
|||
|
|
# fill buffer
|
|||
|
|
v_count = inter_vertical_indices.nelement()
|
|||
|
|
s_count = inter_slash_indices.nelement()
|
|||
|
|
inter_vertical_size_buffer[head_i] = v_count
|
|||
|
|
inter_slash_sizes_buffer[head_i] = s_count
|
|||
|
|
inter_vertical_buffer[head_i, :v_count].copy_(
|
|||
|
|
inter_vertical_indices)
|
|||
|
|
inter_slash_buffer[head_i, :s_count].copy_(
|
|||
|
|
inter_slash_indices)
|
|||
|
|
else:
|
|||
|
|
intra_vertical_indices, intra_slash_indices = None, None
|
|||
|
|
succ_vertical_indices, succ_slash_indices = None, None
|
|||
|
|
inter_vertical_indices, inter_slash_indices = None, None
|
|||
|
|
|
|||
|
|
if sparse_attn_enabled:
|
|||
|
|
flash_result = self._do_flash_attn(
|
|||
|
|
q_states_intra,
|
|||
|
|
k_states_intra,
|
|||
|
|
v_states_intra,
|
|||
|
|
softmax_scale=softmax_scale,
|
|||
|
|
causal=True,
|
|||
|
|
block_table=block_table,
|
|||
|
|
stage="intra",
|
|||
|
|
vertical_indices=vertical_buffer,
|
|||
|
|
slash_indices=slash_buffer,
|
|||
|
|
vertical_indices_count=vertical_size_buffer,
|
|||
|
|
slash_indices_count=slash_sizes_buffer,
|
|||
|
|
mergehead_softmax_scale=softmax_scale,
|
|||
|
|
sparse_attn_enabled=sparse_attn_enabled)
|
|||
|
|
else:
|
|||
|
|
flash_result = self._do_flash_attn(
|
|||
|
|
q_states_intra,
|
|||
|
|
k_states_intra,
|
|||
|
|
v_states_intra,
|
|||
|
|
softmax_scale=softmax_scale,
|
|||
|
|
causal=True,
|
|||
|
|
block_table=block_table,
|
|||
|
|
stage="intra",
|
|||
|
|
vertical_indices=intra_vertical_indices,
|
|||
|
|
slash_indices=intra_slash_indices,
|
|||
|
|
sparse_attn_enabled=sparse_attn_enabled)
|
|||
|
|
flash_per_chunk.append(flash_result)
|
|||
|
|
|
|||
|
|
if prev_chunk_end_pos - chunk_len >= 0:
|
|||
|
|
if sparse_attn_enabled:
|
|||
|
|
flash_result = self._do_flash_attn(
|
|||
|
|
q_states_succ,
|
|||
|
|
k_states_succ,
|
|||
|
|
v_states_succ,
|
|||
|
|
softmax_scale=softmax_scale,
|
|||
|
|
causal=False,
|
|||
|
|
block_table=block_table,
|
|||
|
|
stage="succ",
|
|||
|
|
vertical_indices=succ_vertical_buffer,
|
|||
|
|
slash_indices=succ_slash_buffer,
|
|||
|
|
vertical_indices_count=succ_vertical_size_buffer,
|
|||
|
|
slash_indices_count=succ_slash_sizes_buffer,
|
|||
|
|
mergehead_softmax_scale=softmax_scale,
|
|||
|
|
sparse_attn_enabled=sparse_attn_enabled)
|
|||
|
|
else:
|
|||
|
|
flash_result = self._do_flash_attn(
|
|||
|
|
q_states_succ,
|
|||
|
|
k_states_succ,
|
|||
|
|
v_states_succ,
|
|||
|
|
softmax_scale=softmax_scale,
|
|||
|
|
causal=False,
|
|||
|
|
block_table=block_table,
|
|||
|
|
stage="succ",
|
|||
|
|
vertical_indices=succ_vertical_indices,
|
|||
|
|
slash_indices=succ_slash_indices,
|
|||
|
|
sparse_attn_enabled=sparse_attn_enabled)
|
|||
|
|
flash_per_chunk.append(flash_result)
|
|||
|
|
|
|||
|
|
if prev_chunk_end_pos - chunk_len * 2 >= 0:
|
|||
|
|
if sparse_attn_enabled:
|
|||
|
|
flash_result = self._do_flash_attn(
|
|||
|
|
q_states_inter,
|
|||
|
|
k_states_inter,
|
|||
|
|
v_states_inter,
|
|||
|
|
softmax_scale=softmax_scale,
|
|||
|
|
causal=False,
|
|||
|
|
block_table=block_table,
|
|||
|
|
stage="inter",
|
|||
|
|
vertical_indices=inter_vertical_buffer,
|
|||
|
|
slash_indices=inter_slash_buffer,
|
|||
|
|
vertical_indices_count=inter_vertical_size_buffer,
|
|||
|
|
slash_indices_count=inter_slash_sizes_buffer,
|
|||
|
|
mergehead_softmax_scale=softmax_scale,
|
|||
|
|
sparse_attn_enabled=sparse_attn_enabled)
|
|||
|
|
else:
|
|||
|
|
flash_result = self._do_flash_attn(
|
|||
|
|
q_states_inter,
|
|||
|
|
k_states_inter,
|
|||
|
|
v_states_inter,
|
|||
|
|
softmax_scale=softmax_scale,
|
|||
|
|
causal=False,
|
|||
|
|
block_table=block_table,
|
|||
|
|
stage="inter",
|
|||
|
|
vertical_indices=inter_vertical_indices,
|
|||
|
|
slash_indices=inter_slash_indices,
|
|||
|
|
sparse_attn_enabled=sparse_attn_enabled)
|
|||
|
|
flash_per_chunk.append(flash_result)
|
|||
|
|
|
|||
|
|
flash_results.append(flash_per_chunk)
|
|||
|
|
begin = end
|
|||
|
|
|
|||
|
|
attn_output = self._merge_attn_outputs(flash_results)
|
|||
|
|
del flash_results
|
|||
|
|
return attn_output
|
|||
|
|
|
|||
|
|
def _do_flash_attn(
    self,
    query_states: torch.Tensor,
    key_states: torch.Tensor,
    value_states: torch.Tensor,
    softmax_scale: float,
    causal: bool = True,
    block_table: torch.Tensor = None,
    max_seqlen_k: Optional[int] = None,
    stage: str = "intra",
    vertical_indices: Optional[torch.Tensor] = None,
    slash_indices: Optional[torch.Tensor] = None,
    vertical_indices_count: Optional[torch.Tensor] = None,
    slash_indices_count: Optional[torch.Tensor] = None,
    mergehead_softmax_scale: Optional[float] = None,
    sparse_attn_enabled: Optional[bool] = False,
):
    """Run attention for one (stage, chunk) pair of the prefill pass.

    Dispatches to vertical-slash sparse attention when
    ``sparse_attn_enabled`` is set, otherwise to dense varlen flash
    attention over the paged KV cache.

    Returns:
        Tuple ``(output, softmax_lse)``: ``output`` has shape
        ``(q_len, n_heads, head_dim)`` and ``softmax_lse`` is the
        float32 log-sum-exp reshaped to ``(1, n_heads, q_len)``, the
        layout expected by ``_merge_attn_outputs``.
    """
    if max_seqlen_k is None:
        max_seqlen_k = key_states.shape[0]

    q_len = query_states.shape[0]
    q_heads = query_states.shape[1]
    h_dim = query_states.shape[-1]

    if sparse_attn_enabled:
        assert slash_indices is not None
        # Only the intra-chunk stage attends causally; succ/inter stages
        # see strictly-past keys and must not be re-masked.
        if stage == "intra":
            assert causal
        else:
            assert not causal

        # (seq, heads, dim) -> (1, heads, seq, dim), the layout the
        # sparse kernel expects.
        query_states = query_states.unsqueeze(0).transpose(1, 2)
        key_states = key_states.unsqueeze(0).transpose(1, 2)
        value_states = value_states.unsqueeze(0).transpose(1, 2)

        q = query_states
        k = key_states
        v = value_states

        if (vertical_indices_count is not None
                and slash_indices_count is not None):
            # Merged-head path: per-head index counts were precomputed,
            # so a dedicated softmax scale must accompany them.
            assert mergehead_softmax_scale is not None

            res, s_lse = _vertical_slash_sparse_attention(
                q,
                k,
                v,
                vertical_indices,
                slash_indices,
                mergehead_softmax_scale,
                causal=causal,
                stage=stage,
                vertical_indices_count=vertical_indices_count,
                slash_indices_count=slash_indices_count)
            res = res.view(q_heads, q_len,
                           h_dim).transpose(0, 1)  # (qlen, nhead, h_dim)
            s_lse = s_lse.view(
                q_heads, q_len,
                1).squeeze(-1).unsqueeze(0).float()  # (1, nhead, qlen)
        else:
            res, s_lse = _vertical_slash_sparse_attention(q,
                                                          k,
                                                          v,
                                                          vertical_indices,
                                                          slash_indices,
                                                          softmax_scale,
                                                          causal=causal,
                                                          stage=stage)
            res = res.view(q_len, q_heads, h_dim)
            s_lse = s_lse.view(q_len, q_heads, 1).transpose(0, 2).float()
        return res, s_lse

    # Dense path: single-sequence varlen flash attention; block_table
    # maps this request's logical blocks onto the paged KV cache.
    output, softmax_lse = flash_attn_varlen_func(
        q=query_states,
        k=key_states,
        v=value_states,
        softmax_scale=softmax_scale,
        cu_seqlens_q=torch.tensor([0, query_states.shape[0]],
                                  dtype=torch.int32,
                                  device=query_states.device),
        max_seqlen_q=query_states.shape[0],
        cu_seqlens_k=torch.tensor([0, max_seqlen_k],
                                  dtype=torch.int32,
                                  device=query_states.device),
        max_seqlen_k=max_seqlen_k,
        causal=causal,
        block_table=block_table.unsqueeze(0),
        return_softmax_lse=True,
    )
    softmax_lse = softmax_lse.view(q_len, q_heads,
                                   1).transpose(0, 2).float()
    return output, softmax_lse
|
|||
|
|
|
|||
|
|
def _merge_attn_outputs(
|
|||
|
|
self,
|
|||
|
|
flash_results: List[List[Tuple[torch.Tensor, torch.Tensor]]],
|
|||
|
|
return_lse: Optional[bool] = False,
|
|||
|
|
) -> torch.Tensor:
|
|||
|
|
attn_outputs_all = []
|
|||
|
|
logits_all = []
|
|||
|
|
|
|||
|
|
for flash_per_chunk in flash_results:
|
|||
|
|
if len(flash_per_chunk) == 1:
|
|||
|
|
attn_outputs_all.append(flash_per_chunk[0][0])
|
|||
|
|
if return_lse:
|
|||
|
|
logits_all.append(flash_per_chunk[0][1])
|
|||
|
|
continue
|
|||
|
|
|
|||
|
|
attn_outputs = torch.stack([
|
|||
|
|
flash_attn_output[0] for flash_attn_output in flash_per_chunk
|
|||
|
|
])
|
|||
|
|
logits = torch.stack([
|
|||
|
|
flash_attn_output[1] for flash_attn_output in flash_per_chunk
|
|||
|
|
])
|
|||
|
|
logits = logits.to(torch.float32)
|
|||
|
|
|
|||
|
|
if return_lse:
|
|||
|
|
max_val = torch.max(logits, dim=0).values
|
|||
|
|
diff = torch.abs(logits[0] - logits[1])
|
|||
|
|
log_sum_exp = max_val + torch.log1p(torch.exp(-diff))
|
|||
|
|
logits_all.append(log_sum_exp)
|
|||
|
|
|
|||
|
|
max_logits = torch.max(logits, dim=0).values
|
|||
|
|
stable_logits = logits - max_logits.unsqueeze(0)
|
|||
|
|
lse_s = torch.exp(stable_logits).detach()
|
|||
|
|
lse_sum = torch.sum(lse_s, dim=0)
|
|||
|
|
lse_s /= lse_sum
|
|||
|
|
attn_outputs *= lse_s.unsqueeze(-1).transpose(2, 3).squeeze(1)
|
|||
|
|
attn_outputs_all.append(attn_outputs.sum(dim=0))
|
|||
|
|
|
|||
|
|
if return_lse:
|
|||
|
|
return (torch.cat(attn_outputs_all,
|
|||
|
|
dim=0), torch.cat(logits_all, dim=-1))
|
|||
|
|
else:
|
|||
|
|
return torch.cat(attn_outputs_all, dim=0)
|
|||
|
|
|
|||
|
|
def _dual_chunk_flash_attn_decoding(
    self,
    query: torch.Tensor,
    query_succ: torch.Tensor,
    query_inter: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    block_table: torch.Tensor,
    cache_seqlens: torch.Tensor,
    softmax_scale: float,
    causal: bool,
    alibi_slopes: Optional[torch.Tensor],
    chunk_size: int,
    local_size: int,
    original_max_position_embeddings: int,
    decode_meta: "DualChunkFlashAttentionMetadata",
):
    """Dual-chunk attention for the decode phase.

    Runs up to three passes over the paged KV cache -- intra (current
    chunk), succ (immediately preceding chunk) and inter (all older
    chunks) -- each with its own query projection, then merges them
    with numerically-stable log-sum-exp softmax weights.

    Raises:
        ValueError: if ``causal`` is False, or ``chunk_size -
            local_size`` is not a multiple of the cache block size.
    """
    if not causal:
        raise ValueError(
            "Dual Chunk Attention does not support causal=False")

    block_size = value_cache.shape[1]
    chunk_len = chunk_size - local_size
    if chunk_len % block_size != 0:
        raise ValueError("chunk_len must be divisible by block_size.")
    if original_max_position_embeddings > 0:
        # Rescale queries when length extrapolation is active.
        assert decode_meta.scaling_factor is not None
        scaling_factor = decode_meta.scaling_factor
        query = (query * scaling_factor.view(-1, 1, 1, 1)).to(
            query.dtype
        )  # possible for numerical issue, need to fused in the kernel
        query_succ = (query_succ * scaling_factor.view(-1, 1, 1, 1)).to(
            query.dtype)
        query_inter = (query_inter * scaling_factor.view(-1, 1, 1, 1)).to(
            query.dtype)
    outputs_list = []
    softmax_lses_list = []

    # intra-attention: keys of the chunk currently being decoded.
    intra_output, intra_softmax_lse = (
        self._dual_chunk_flash_attn_decoding_with_exp_sums(
            query,
            key_cache,
            value_cache,
            decode_meta.block_tables_intra,
            decode_meta.seq_lens_intra,
            softmax_scale,
            alibi_slopes,
            causal=False,
        ))
    outputs_list.append(intra_output)
    softmax_lses_list.append(intra_softmax_lse)

    # succ-attention: keys of the immediately preceding chunk.
    if decode_meta.max_seq_len_succ:
        succ_output, succ_softmax_lse = (
            self._dual_chunk_flash_attn_decoding_with_exp_sums(
                query_succ,
                key_cache,
                value_cache,
                decode_meta.block_tables_succ,
                decode_meta.seq_lens_succ,
                softmax_scale,
                alibi_slopes,
                causal=False,
            ))
        outputs_list.append(succ_output)
        softmax_lses_list.append(succ_softmax_lse)

    # inter-attention: all remaining (older) chunks.
    if decode_meta.max_seq_len_inter:
        inter_output, inter_softmax_lse = (
            self._dual_chunk_flash_attn_decoding_with_exp_sums(
                query_inter,
                key_cache,
                value_cache,
                block_table[:, :decode_meta.max_seq_len_inter],
                decode_meta.seq_lens_inter,
                softmax_scale,
                alibi_slopes,
                causal=False,
            ))
        outputs_list.append(inter_output)
        softmax_lses_list.append(inter_softmax_lse)
    outputs = torch.stack(outputs_list, dim=0)
    del outputs_list
    softmax_lses = torch.stack(softmax_lses_list, dim=0).to(torch.float32)
    del softmax_lses_list
    # Merge the passes: stable softmax over per-pass LSE values.  Passes
    # skipped above simply never enter the stacks; passes with empty
    # caches carry -inf LSE and get zero weight.
    max_logits = torch.max(softmax_lses, dim=0).values
    stable_logits = softmax_lses - max_logits.unsqueeze(0)
    lse_s = torch.exp(stable_logits).detach()
    lse_sum = torch.sum(lse_s, dim=0)
    lse_s /= lse_sum
    outputs *= lse_s.unsqueeze(-1).transpose(2, 3)
    return outputs.sum(0)
|
|||
|
|
|
|||
|
|
def _dual_chunk_flash_attn_decoding_with_exp_sums(
    self,
    query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    block_table: torch.Tensor,
    cache_seqlens: torch.Tensor,
    softmax_scale: float,
    alibi_slopes: Optional[torch.Tensor],
    causal: bool,
):
    """Single decode attention pass, returning output and softmax LSE.

    Sequences with a zero cache length take no part in this pass: their
    output rows are zeroed and their LSE forced to ``-inf`` so the
    caller's log-sum-exp merge assigns them zero weight.
    """
    out, softmax_lse = flash_attn_with_kvcache(
        q=query,
        k_cache=key_cache,
        v_cache=value_cache,
        block_table=block_table,
        cache_seqlens=cache_seqlens,
        softmax_scale=softmax_scale,
        alibi_slopes=alibi_slopes,
        causal=causal,
        return_softmax_lse=True,
    )
    # Neutralize sequences that have no keys in this pass.
    mask = (cache_seqlens == 0)
    out[mask] = 0
    softmax_lse[mask] = -float("inf")
    return out, softmax_lse
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _vertical_slash_sparse_attention(
    query: torch.Tensor,  # [BATCH, N_HEADS, N_CTX, D_HEAD]
    key: torch.Tensor,  # [BATCH, N_HEADS, N_KV_CTX, D_HEAD]
    value: torch.Tensor,  # [BATCH, N_HEADS, N_KV_CTX, D_HEAD]
    v_idx: torch.Tensor,  # [BATCH, N_HEADS, NNZ_V]
    s_idx: torch.Tensor,  # [BATCH, N_HEADS, NNZ_S]
    softmax_scale: float,
    causal: bool = True,
    stage: str = "intra",
    block_size_M: int = 64,
    block_size_N: int = 64,
    vertical_indices_count: torch.Tensor = None,  # [N_HEADS,]
    slash_indices_count: torch.Tensor = None,
):
    """Attention restricted to vertical columns and slash diagonals.

    Converts the per-head vertical/slash index sets into block-sparse
    layout metadata (via the custom ops) and runs the sparse attention
    kernel over it.

    Returns:
        ``(out, lse)`` with padding and any extra context trimmed:
        ``out`` is ``[..., :context_size, :head_dim]`` and ``lse`` has a
        trailing singleton dim appended.
    """
    # Intra-chunk queries attend causally; succ/inter stages must not.
    if stage == "intra":
        assert causal
    else:
        assert not causal

    batch_size, num_heads, context_size, head_dim = query.shape
    _, _, kv_seq_len, _ = key.shape

    # The kernel supports only these head dims; pad the head dim up to
    # the next power of two and slice it back off the output below.
    if head_dim not in [16, 32, 64, 128, 256, 512]:
        target_dim = 2**math.ceil(math.log2(head_dim)) - head_dim
        query = F.pad(query, [0, target_dim, 0, 0, 0, 0, 0, 0])
        key = F.pad(key, [0, target_dim, 0, 0, 0, 0, 0, 0])
        value = F.pad(value, [0, target_dim, 0, 0, 0, 0, 0, 0])

    # Index-conversion kernels expect vertical indices ascending and
    # slash indices descending.
    v_idx = v_idx.to(torch.int32).reshape(
        (batch_size, num_heads, -1)).sort(dim=-1, descending=False)[0]
    s_idx = s_idx.to(torch.int32).reshape(
        (batch_size, num_heads, -1)).sort(dim=-1, descending=True)[0]
    q_seqlens = torch.tensor([context_size],
                             dtype=torch.int32,
                             device=query.device)
    kv_seqlens = torch.tensor([kv_seq_len],
                              dtype=torch.int32,
                              device=query.device)

    if vertical_indices_count is not None and slash_indices_count is not None:
        # Merged-head variant: per-head valid-index counts were given,
        # so the buffers may contain padding entries.
        (
            block_count,
            block_offset,
            column_count,
            column_index,
        ) = ops.convert_vertical_slash_indexes_mergehead(
            q_seqlens, kv_seqlens, v_idx, s_idx, vertical_indices_count,
            slash_indices_count, context_size, block_size_M, block_size_N,
            causal)
    else:
        (
            block_count,
            block_offset,
            column_count,
            column_index,
        ) = ops.convert_vertical_slash_indexes(q_seqlens, kv_seqlens, v_idx,
                                               s_idx, context_size,
                                               block_size_M, block_size_N,
                                               causal)

    # Sparse kernel expects (batch, seq, heads, dim).
    q = query.transpose(1, 2).contiguous()
    k = key.transpose(1, 2).contiguous()
    v = value.transpose(1, 2).contiguous()
    out, lse = sparse_attn_func(
        q,
        k,
        v,
        block_count,
        block_offset,
        column_count,
        column_index,
        causal=causal,
        softmax_scale=softmax_scale,
        return_softmax_lse=True,
    )
    out = out.transpose(1, 2).contiguous()
    softmax_lse = lse.reshape(*lse.shape, 1)
    # Trim the head-dim padding and any rows beyond context_size.
    return (out[..., :context_size, :head_dim],
            softmax_lse[..., :context_size, :])
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _sum_all_diagonal_matrix(mat: torch.tensor):
|
|||
|
|
h, n, m = mat.shape
|
|||
|
|
# Zero matrix used for padding
|
|||
|
|
zero_mat = torch.zeros((h, n, n), device=mat.device)
|
|||
|
|
# pads the matrix on left and right
|
|||
|
|
mat_padded = torch.cat((zero_mat, mat, zero_mat), -1)
|
|||
|
|
# Change the strides
|
|||
|
|
mat_strided = mat_padded.as_strided((1, n, n + m),
|
|||
|
|
(n * (2 * n + m), 2 * n + m + 1, 1))
|
|||
|
|
# Sums the resulting matrix's columns
|
|||
|
|
sum_diags = torch.sum(mat_strided, 1)
|
|||
|
|
return sum_diags[:, 1:] # drop left bottom corner
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _get_block(block_table: torch.Tensor, block_size: int, begin: int,
|
|||
|
|
end: int):
|
|||
|
|
begin_block = begin // block_size
|
|||
|
|
end_block = (end - 1) // block_size + 1
|
|||
|
|
return block_table[begin_block:end_block]
|