# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project

"""Attention layer with FlashAttention."""

from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Optional, ClassVar

import numpy as np
import torch
import torch.nn.functional as F

from vllm.attention.backends.abstract import (AttentionImpl,
                                              AttentionMetadata, AttentionType,
                                              is_quantized_kv_cache,
                                              MultipleOf)
from vllm.attention.backends.utils import PAD_SLOT_ID
from vllm.attention.utils.fa_utils import (flash_attn_supports_sinks,
                                           get_flash_attn_version)
from vllm.logger import init_logger
from vllm.model_executor.layers.batch_invariant import (
    vllm_is_batch_invariant,
)
from vllm.platforms import current_platform
from vllm.config.vllm import VllmConfig
from vllm.v1.worker.block_table import BlockTable

from vllm.v1.attention.backends.flash_attn import (
    FlashAttentionBackend, FlashAttentionMetadata,
    FlashAttentionMetadataBuilder,
    _get_sliding_window_configs
)
from vllm.v1.attention.backends.mla.common import MLACommonMetadata
from vllm.v1.attention.backends.utils import (
    AttentionCGSupport,
    get_dcp_local_seq_lens,
    split_decodes_and_prefills,
)
from vllm.v1.kv_cache_interface import AttentionSpec

if TYPE_CHECKING:
    from vllm_mlu.v1.worker.gpu_model_runner import MLUModelRunner

if current_platform.is_cuda():
    from vllm.attention.utils.fa_utils import get_scheduler_metadata

from vllm_mlu import _mlu_ops as mlu_ops
from vllm_mlu.v1.attention.backends.utils import (
    MLUCommonAttentionMetadata,
    MLUInferMode,
    get_common_metadata,
)
from vllm_mlu.model_executor.layers.quantization.utils.common_utils import attn_str_dtype_to_torch

logger = init_logger(__name__)


class MLUFlashAttentionBackend(FlashAttentionBackend):
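    """MLU FlashAttention backend: wires up the MLU-specific impl, metadata,
    and metadata-builder classes and defines the MLU KV-cache layout."""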

    supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [1, 16, 32, 64]

    @classmethod
    def get_supported_head_sizes(cls) -> list[int]:
        return [32, 64, 80, 96, 128, 160, 192, 224, 256, 512, 576]

    @staticmethod
    def get_impl_cls() -> type["MLUFlashAttentionImpl"]:
        return MLUFlashAttentionImpl

    @staticmethod
    def get_metadata_cls() -> type["AttentionMetadata"]:
        return MLUFlashAttentionMetadata

    @staticmethod
    def get_builder_cls() -> type["MLUFlashAttentionMetadataBuilder"]:
        return MLUFlashAttentionMetadataBuilder

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
        cache_dtype_str: str = "auto",
    ) -> tuple[int, ...]:
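        # MLU KV-cache layout: key and value caches are stacked along dim 0,
        # and each block stores [num_kv_heads, block_size, head_size].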
        return (2, num_blocks, num_kv_heads, block_size, head_size)

    @staticmethod
    def get_kv_cache_scale_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
    ) -> tuple[int, ...]:
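        # One quantization scale per cached token per KV head, for key and
        # value separately.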
        return (2, num_blocks, num_kv_heads, block_size)


@dataclass
class MLUChunkFlashAttentionMetadata:
    """
    Chunked prefill metadata for the MLU backend, which splits both the
    input and the metadata into prefill and decode phases. With this split,
    the MLU backend can invoke the FA and single_query_cached_kv_attn kernels
    separately, which yields better performance.
    """

    @dataclass
    class ChunkContextMetadata:
        """
        ChunkContextMetadata for prefill chunks and decode tokens.
        """
        batch_size: int
        num_actual_tokens: int
        cu_seqlens_q: torch.Tensor
        cu_seqlens_kv: torch.Tensor
        max_query_len: int
        max_seq_len: int
        total_seqlens: int = 0

    prefill_ctx: ChunkContextMetadata
    decode_ctx: ChunkContextMetadata

    @classmethod
    def build(
        cls,
        common_attn_metadata: MLUCommonAttentionMetadata,
        uniform_decode_query_len: int = 1,
    ):
        assert common_attn_metadata.infer_mode.is_chunked
        (
            num_decodes,
            num_prefills,
            num_decode_tokens,
            num_prefill_tokens,
        ) = split_decodes_and_prefills(common_attn_metadata,
                                       uniform_decode_query_len,
                                       require_uniform=True)
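        # NOTE: requests are ordered with all decodes before all prefills, so
        # the first `num_decodes` entries of the per-request tensors below
        # belong to the decode phase.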
        # split cu_seqlens_q and cu_seqlens_kv
        query_start_loc = common_attn_metadata.query_start_loc
        d_cu_seqlens_q = query_start_loc[:num_decodes + 1]
        p_cu_seqlens_q = query_start_loc[num_decodes:] - query_start_loc[num_decodes]
        seq_start_loc = common_attn_metadata.seq_start_loc
        d_cu_seqlens_kv = seq_start_loc[:num_decodes + 1]
        p_cu_seqlens_kv = seq_start_loc[num_decodes:] - seq_start_loc[num_decodes]
        # compute max_query_len and max_seq_len after split
        # NOTE: use cpu tensor to avoid d2h copy.
        query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
        query_len_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
        seq_len_cpu = common_attn_metadata.seq_lens_cpu
        d_max_query_len = 0
        d_max_seq_len = 0
        p_max_query_len = 0
        p_max_seq_len = 0
        p_total_seqlens = 0
        if num_decodes > 0:
            d_max_query_len = query_len_cpu[:num_decodes].max().item()
            d_max_seq_len = seq_len_cpu[:num_decodes].max().item()
        if num_prefills > 0:
            p_max_query_len = query_len_cpu[num_decodes:].max().item()
            p_max_seq_len = seq_len_cpu[num_decodes:].max().item()
            p_total_seqlens = seq_len_cpu[num_decodes:].sum().item()

        return MLUChunkFlashAttentionMetadata(
            prefill_ctx=MLUChunkFlashAttentionMetadata.ChunkContextMetadata(
                batch_size=num_prefills,
                num_actual_tokens=num_prefill_tokens,
                cu_seqlens_q=p_cu_seqlens_q,
                cu_seqlens_kv=p_cu_seqlens_kv,
                max_query_len=p_max_query_len,
                max_seq_len=p_max_seq_len,
                total_seqlens=p_total_seqlens,
            ),
            decode_ctx=MLUChunkFlashAttentionMetadata.ChunkContextMetadata(
                batch_size=num_decodes,
                num_actual_tokens=num_decode_tokens,
                cu_seqlens_q=d_cu_seqlens_q,
                cu_seqlens_kv=d_cu_seqlens_kv,
                max_query_len=d_max_query_len,
                max_seq_len=d_max_seq_len,
            ),
        )


@dataclass
class MLUFlashAttentionMetadata(FlashAttentionMetadata):
    # For mlu infer
    seq_start_loc: torch.Tensor | None = None
    infer_mode: MLUInferMode | None = None
    num_input_tokens: int = 0  # Number of tokens including padding.
    compute_dtype: torch.dtype = torch.float32
    chunk_fa_metadata: MLUChunkFlashAttentionMetadata | None = None

    @property
    def num_decode_tokens(self):
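        """Number of decode tokens in the current batch, derived from the
        infer mode (and from chunk_fa_metadata in the chunked case)."""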
        assert self.infer_mode is not None, (
            "MLUFlashAttentionMetadata infer_mode is not set."
        )

        if self.infer_mode == MLUInferMode.PREFILL_ONLY:
            return 0

        if self.infer_mode == MLUInferMode.DECODE_ONLY:
            return self.num_actual_tokens

        assert self.chunk_fa_metadata is not None, (
            "chunk_fa_metadata must be set under chunked infer mode."
        )
        return self.chunk_fa_metadata.decode_ctx.num_actual_tokens


def pad_attn_metadata(
    attn_metadata: MLACommonMetadata | FlashAttentionMetadata,
    common_metadata: MLUCommonAttentionMetadata,
    block_table: BlockTable,
    runner: "MLUModelRunner",
    num_scheduled_tokens: int,
    num_input_tokens: int,
    num_reqs: int,
    num_paded_reqs: int,
) -> None:
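    """Pad attention metadata in place so that it describes a padded batch of
    ``num_input_tokens`` tokens and ``num_paded_reqs`` requests (e.g. for
    graph capture).

    Padding query tokens are spread evenly over the padded requests, padded
    slots are filled with ``PAD_SLOT_ID``, and the updated tensors are written
    back to both ``attn_metadata`` and ``common_metadata``.
    """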
    is_mla = isinstance(attn_metadata, MLACommonMetadata)
    if is_mla:
        assert attn_metadata.prefill is None and attn_metadata.decode is not None

    pad_token_num = num_input_tokens - num_scheduled_tokens
    pad_req_num = num_paded_reqs - num_reqs
    if pad_token_num == 0:
        return

    query_start_loc_cpu = runner.query_start_loc.cpu[:num_paded_reqs + 1]
    query_start_loc = runner.query_start_loc.gpu[:num_paded_reqs + 1]
    seq_lens_cpu = runner.seq_lens.cpu[:num_paded_reqs]
    seq_lens = runner.seq_lens.gpu[:num_paded_reqs]
    if pad_req_num > 0:
        query_lens = torch.diff(query_start_loc_cpu[:num_reqs + 1])
        pad_lens = torch.full(
            (pad_req_num,),
            pad_token_num // pad_req_num,
            dtype=query_lens.dtype,
            device=query_lens.device)
        query_lens = torch.cat([query_lens, pad_lens])
        torch.cumsum(query_lens, dim=0, out=query_start_loc_cpu[1:])
        query_start_loc.copy_(query_start_loc_cpu, non_blocking=True)
        seq_lens_cpu[num_reqs:].fill_(common_metadata.max_query_len)
        seq_lens[num_reqs:].fill_(common_metadata.max_query_len)

    seq_start_loc_cpu = runner.seq_start_loc.cpu[:(num_paded_reqs + 1)]
    seq_start_loc = runner.seq_start_loc.gpu[:(num_paded_reqs + 1)]
    torch.cumsum(seq_lens, dim=0, out=seq_start_loc[1:])
    torch.cumsum(seq_lens_cpu, dim=0, out=seq_start_loc_cpu[1:])

    slot_mapping_org_num = attn_metadata.slot_mapping.numel()
    slot_mapping = block_table.slot_mapping.gpu[:(slot_mapping_org_num + pad_token_num)]
    slot_mapping[slot_mapping_org_num:] = PAD_SLOT_ID

    block_table = block_table.get_device_tensor(num_paded_reqs)

    attn_metadata.slot_mapping = slot_mapping
    attn_metadata.query_start_loc = query_start_loc
    if is_mla:
        attn_metadata.decode.query_start_loc = query_start_loc
        attn_metadata.decode.seq_lens = seq_lens
        attn_metadata.decode.block_table = block_table
    else:
        attn_metadata.seq_lens = seq_lens
        attn_metadata.seq_start_loc = seq_start_loc
        attn_metadata.block_table = block_table

    common_metadata.num_input_tokens = num_input_tokens
    common_metadata.seq_start_loc = seq_start_loc
    common_metadata.seq_start_loc_cpu = seq_start_loc_cpu
    common_metadata.query_start_loc = query_start_loc
    common_metadata.query_start_loc_cpu = query_start_loc_cpu
    common_metadata.seq_lens = seq_lens
    common_metadata.seq_lens_cpu = seq_lens_cpu
    common_metadata.num_reqs = num_paded_reqs
    common_metadata.block_table_tensor = block_table
    common_metadata.slot_mapping = slot_mapping


class MLUFlashAttentionMetadataBuilder(FlashAttentionMetadataBuilder):
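    """Builds MLUFlashAttentionMetadata, including the optional chunked
    prefill/decode split metadata consumed by the MLU attention kernels."""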

    cudagraph_support = (
        AttentionCGSupport.UNIFORM_BATCH
    )

    def __init__(
        self,
        kv_cache_spec: AttentionSpec,
        layer_names: list[str],
        vllm_config: VllmConfig,
        device: torch.device,
    ):
        super().__init__(kv_cache_spec, layer_names, vllm_config, device)
        '''
        =============================
        Modify by vllm_mlu
        =============================
        @brief: add class member - uniform_decode_query_len
        '''
        self.uniform_decode_query_len = (
            1 if not self.vllm_config.speculative_config
            else 1 + self.vllm_config.speculative_config.num_speculative_tokens
        )
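        # With speculative decoding, each decode step carries the target token
        # plus num_speculative_tokens draft tokens, so a "uniform" decode query
        # spans that many tokens.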
        '''
        ==================
        End of MLU Hijack
        ==================
        '''

    def build(
        self,
        common_prefix_len: int,
        common_attn_metadata: MLUCommonAttentionMetadata,
        fast_build: bool = False,
    ) -> MLUFlashAttentionMetadata:
        """
        fast_build disables AOT scheduling, used when there will be few
        iterations, e.g. spec-decode.
        """
        num_reqs = common_attn_metadata.num_reqs
        num_actual_tokens = common_attn_metadata.num_actual_tokens
        max_query_len = common_attn_metadata.max_query_len
        max_seq_len = common_attn_metadata.max_seq_len
        query_start_loc = common_attn_metadata.query_start_loc
        seq_lens = common_attn_metadata.seq_lens
        seq_lens_cpu = common_attn_metadata.seq_lens_cpu
        block_table_tensor = common_attn_metadata.block_table_tensor
        slot_mapping = common_attn_metadata.slot_mapping
        causal = common_attn_metadata.causal
        '''
        =============================
        Modify by vllm_mlu
        =============================
        @brief: add seq_start_loc for chunk fa
        '''
        seq_start_loc = common_attn_metadata.seq_start_loc
        '''
        ==================
        End of MLU Hijack
        ==================
        '''

        # the overhead of the aot schedule is not worth it for spec-decode
        aot_schedule = self.aot_schedule and not fast_build

        if self.aot_sliding_window is None:
            self.aot_sliding_window = (-1, -1)
            # For the AOT scheduler we need the sliding window value to be
            # constant for all layers. We have to populate this on the first
            # build() call, once the layers are constructed (it cannot be
            # populated in __init__).
            if aot_schedule:
                sliding_window_configs = _get_sliding_window_configs(self.vllm_config)
                if len(sliding_window_configs) == 1:
                    sliding_window_config = sliding_window_configs.pop()
                    if sliding_window_config is not None:
                        self.aot_sliding_window = sliding_window_config
                elif len(sliding_window_configs) > 1:
                    self.aot_schedule = False
                    aot_schedule = False

        max_num_splits = 0  # 0 means use FA3's heuristics, not CG compatible
        if self.use_full_cuda_graph and num_actual_tokens <= self.max_cudagraph_size:
            # NOTE(woosuk): Setting num_splits > 1 may increase the memory
            # usage, because the intermediate buffers of size [num_splits,
            # num_heads, num_tokens, head_size] are allocated. Therefore,
            # we only set num_splits when using cuda graphs.
            max_num_splits = self.max_num_splits

        if vllm_is_batch_invariant():
            max_num_splits = 1

        def schedule(
            batch_size, cu_query_lens, max_query_len, seqlens, max_seq_len, causal
        ):
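            # NOTE: get_scheduler_metadata is only imported on CUDA platforms
            # (see the conditional import at the top of this file); when AOT
            # scheduling is disabled this helper simply returns None.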
            cache_dtype = self.cache_config.cache_dtype
            if cache_dtype.startswith("fp8"):
                qkv_dtype = FlashAttentionBackend.get_fp8_dtype_for_flashattn(
                    cache_dtype
                )
            else:
                qkv_dtype = self.kv_cache_dtype
            if aot_schedule:
                return get_scheduler_metadata(
                    batch_size=batch_size,
                    max_seqlen_q=max_query_len,
                    max_seqlen_k=max_seq_len,
                    num_heads_q=self.num_heads_q * self.dcp_world_size,
                    num_heads_kv=self.num_heads_kv,
                    headdim=self.headdim,
                    cache_seqlens=seqlens,
                    qkv_dtype=qkv_dtype,
                    cu_seqlens_q=cu_query_lens,
                    page_size=self.block_size,
                    causal=causal,
                    window_size=self.aot_sliding_window,
                    num_splits=max_num_splits,
                )
            return None

        use_cascade = common_prefix_len > 0
        max_dcp_context_kv_len = 0
        dcp_context_kv_lens = None

        cu_prefix_query_lens = None
        prefix_kv_lens = None
        suffix_kv_lens = None
        prefix_scheduler_metadata = None

        if self.dcp_world_size > 1:
            query_kv_lens_cpu = (
                common_attn_metadata.query_start_loc_cpu[1:]
                - common_attn_metadata.query_start_loc_cpu[:-1]
            )
            dcp_context_kv_lens_cpu = seq_lens_cpu - query_kv_lens_cpu

            dcp_context_kv_lens_cpu = get_dcp_local_seq_lens(
                dcp_context_kv_lens_cpu,
                self.dcp_world_size,
                self.dcp_rank,
                self.dcp_kv_cache_interleave_size,
            )
            dcp_context_kv_lens = dcp_context_kv_lens_cpu.to(self.device)
            max_dcp_context_kv_len = dcp_context_kv_lens.max().item()

            scheduler_metadata = schedule(
                batch_size=num_reqs,
                cu_query_lens=query_start_loc,
                max_query_len=max_query_len,
                seqlens=dcp_context_kv_lens,
                max_seq_len=max_dcp_context_kv_len,
                causal=False,
            )
        elif use_cascade:
            cu_prefix_query_lens = torch.tensor(
                [0, num_actual_tokens], dtype=torch.int32, device=self.device
            )
            prefix_kv_lens = torch.tensor(
                [common_prefix_len], dtype=torch.int32, device=self.device
            )
            suffix_kv_lens = (seq_lens_cpu[:num_reqs] - common_prefix_len).to(
                self.device, non_blocking=True
            )
            prefix_scheduler_metadata = schedule(
                batch_size=1,
                cu_query_lens=cu_prefix_query_lens,
                max_query_len=num_actual_tokens,
                seqlens=prefix_kv_lens,
                max_seq_len=common_prefix_len,
                causal=False,
            )
            scheduler_metadata = schedule(
                batch_size=num_reqs,
                cu_query_lens=query_start_loc,
                max_query_len=max_query_len,
                seqlens=suffix_kv_lens,
                max_seq_len=max_seq_len - common_prefix_len,
                causal=True,
            )
        else:
            scheduler_metadata = schedule(
                batch_size=num_reqs,
                cu_query_lens=query_start_loc,
                max_query_len=max_query_len,
                seqlens=seq_lens,
                max_seq_len=max_seq_len,
                causal=causal,
            )
        # For FA3 + full cudagraph
        if self.use_full_cuda_graph and scheduler_metadata is not None:
            n = scheduler_metadata.shape[0]
            self.scheduler_metadata[:n] = scheduler_metadata
            # NOTE(woosuk): We should zero out the rest of the scheduler
            # metadata to guarantee the correctness. Otherwise, some thread
            # blocks may use the invalid scheduler metadata and overwrite the
            # output buffer.
            self.scheduler_metadata[n:] = 0
            scheduler_metadata = self.scheduler_metadata[:n]

        '''
        =============================
        Modify by vllm_mlu
        =============================
        @brief: 1. build MLUChunkFlashAttentionMetadata to split prefill and decode;
                2. replace metadata with MLUFlashAttentionMetadata.
        '''
        chunk_fa_metadata = None
        if common_attn_metadata.infer_mode.is_chunked:
            chunk_fa_metadata = MLUChunkFlashAttentionMetadata.build(
                common_attn_metadata,
                self.uniform_decode_query_len,
            )

        attn_metadata = MLUFlashAttentionMetadata(
            num_actual_tokens=num_actual_tokens,
            max_query_len=max_query_len,
            query_start_loc=query_start_loc,
            max_seq_len=max_seq_len,
            seq_lens=seq_lens,
            block_table=block_table_tensor,
            slot_mapping=slot_mapping,
            max_dcp_context_kv_len=max_dcp_context_kv_len,
            dcp_context_kv_lens=dcp_context_kv_lens,
            use_cascade=use_cascade,
            common_prefix_len=common_prefix_len,
            scheduler_metadata=scheduler_metadata,
            cu_prefix_query_lens=cu_prefix_query_lens,
            prefix_kv_lens=prefix_kv_lens,
            suffix_kv_lens=suffix_kv_lens,
            prefix_scheduler_metadata=prefix_scheduler_metadata,
            max_num_splits=max_num_splits,
            causal=causal,
            # For mlu infer
            seq_start_loc=common_attn_metadata.seq_start_loc,
            infer_mode=common_attn_metadata.infer_mode,
            chunk_fa_metadata=chunk_fa_metadata,
        )
        '''
        ==================
        End of MLU Hijack
        ==================
        '''

        return attn_metadata


class MLUFlashAttentionImpl(AttentionImpl):
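    """FlashAttention implementation for MLU devices.

    Prefill, chunked prefill, and decode are dispatched to the corresponding
    mlu_ops kernels (flash_attention / single_query_cached_kv_attn).
    """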

    can_return_lse_for_decode: bool = True

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: list[float] | None,
        sliding_window: int | None,
        kv_cache_dtype: str,
        logits_soft_cap: float | None = None,
        attn_type: AttentionType = AttentionType.DECODER,
        kv_sharing_target_layer_name: str | None = None,
        sinks: torch.Tensor | None = None,
        **extra_impl_args,
    ) -> None:
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_kv_heads
        '''
        =============================
        Modify by vllm_mlu
        =============================
        @brief: 1. move alibi_slopes to mlu;
                2. sliding_window_right only supports -1;
                3. add self.use_fused_mla_qkv.
        '''
        if alibi_slopes is not None:
            alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32).mlu()
        self.alibi_slopes = alibi_slopes
        if sliding_window is None:
            self.sliding_window = (-1, -1)
        elif attn_type == AttentionType.ENCODER_ONLY:
            self.sliding_window = (sliding_window - 1, sliding_window - 1)
        else:
            self.sliding_window = (sliding_window - 1, 0)
        self.is_mla = extra_impl_args.get("is_mla", False)
        self.use_fused_mla_qkv = extra_impl_args.get("use_fused_mla_qkv", False)

        self.decoder_attn_dtype = extra_impl_args.get("decoder_attn_dtype", None)
        '''
        ==================
        End of MLU Hijack
        ==================
        '''
        self.kv_cache_dtype = kv_cache_dtype
        if logits_soft_cap is None:
            # In flash-attn, setting logits_soft_cap as 0 means no soft cap.
            logits_soft_cap = 0
        self.logits_soft_cap = logits_soft_cap
        self.kv_sharing_target_layer_name = kv_sharing_target_layer_name

        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        self.attn_type = attn_type
        self.vllm_flash_attn_version = get_flash_attn_version()
        # Cache the batch invariant result for use in forward passes
        self.batch_invariant_enabled = vllm_is_batch_invariant()

        self.sinks = sinks
        if self.sinks is not None:
            assert flash_attn_supports_sinks(), (
                "Sinks are only supported in FlashAttention 3"
            )
            assert self.sinks.shape[0] == num_heads, (
                "Sinks must have the same number of heads as the number of "
                "heads in the layer"
            )

    def forward(
        self,
        layer: torch.nn.Module,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: MLUFlashAttentionMetadata,
        output: torch.Tensor | None = None,
        output_scale: torch.Tensor | None = None,
        output_block_scale: torch.Tensor | None = None,
        kwargs: dict[str, Any] = {},
    ) -> torch.Tensor:
        """Forward pass with FlashAttention.

        Args:
            query: shape = [num_tokens, num_heads, head_size]
            key: shape = [num_tokens, num_kv_heads, head_size]
            value: shape = [num_tokens, num_kv_heads, head_size]
            kv_cache: shape =
                [2, num_blocks, block_size, num_kv_heads, head_size]
            attn_metadata: Metadata for attention.
        Returns:
            shape = [num_tokens, num_heads * head_size]
        NOTE: For FP8 quantization, flash-attn expects the size of
              {q,k,v}_descale to be (num_sequences, num_kv_heads).
              We use torch's .expand() to avoid duplicating values.
        """
        assert output is not None, "Output tensor must be provided."

        if output_scale is not None or output_block_scale is not None:
            raise NotImplementedError(
                "fused output quantization is not yet supported for FlashAttentionImpl"
            )

        if attn_metadata is None:
            # Profiling run.
            return output.fill_(0)

        '''
        =============================
        Modify by vllm_mlu
        =============================
        @brief: set mlu infer mode.
        '''
        infer_mode = attn_metadata.infer_mode
        assert not attn_metadata.use_cascade, (
            f"MLU does not support use_cascade={attn_metadata.use_cascade}, "
            f"attn_metadata={attn_metadata}."
        )
        assert self.dcp_world_size <= 1, (
            f"MLU does not support dcp_world_size={self.dcp_world_size}."
        )
        '''
        ==================
        End of MLU Hijack
        ==================
        '''

        attn_type = self.attn_type

        # IMPORTANT!
        # NOTE(woosuk): With piece-wise CUDA graphs, this method is executed in
        # eager-mode PyTorch. Thus, we need to be careful about any CPU overhead
        # in this method. For example, `view` and `slice` (or `[:n]`) operations
        # are surprisingly slow even in the case they do not invoke any GPU ops.
        # Minimize the PyTorch ops in this method as much as possible.
        # Whenever making a change in this method, please benchmark the
        # performance to make sure it does not introduce any overhead.

        num_actual_tokens = attn_metadata.num_actual_tokens

        # Handle encoder attention differently - no KV cache needed
        if attn_type in (AttentionType.ENCODER_ONLY, AttentionType.ENCODER):
            # For encoder attention,
            # we use direct Q, K, V tensors without caching
            return self._forward_encoder_attention(
                query[:num_actual_tokens],
                key[:num_actual_tokens],
                value[:num_actual_tokens],
                output[:num_actual_tokens],
                attn_metadata,
                layer,
            )

        # For decoder and cross-attention, use KV cache as before
        '''
        =============================
        Modify by vllm_mlu
        =============================
        @brief: kv_cache[0] is [key_cache, value_cache], and
                kv_cache[1] is [key_cache_scale, value_cache_scale].
        '''
        key_cache, value_cache = kv_cache[0].unbind(0)
        if is_quantized_kv_cache(self.kv_cache_dtype):
            key_cache_scale, value_cache_scale = kv_cache[1].unbind(0)
        else:
            key_cache_scale = None
            value_cache_scale = None
        '''
        ==================
        End of MLU Hijack
        ==================
        '''

        # key and value may be None in the case of cross attention. They are
        # calculated once based on the output from the encoder and then cached
        # in KV cache.
        '''
        =============================
        Modify by vllm_mlu
        =============================
        @brief: skip storing key/value to the KV cache in the MLA prefill phase.
        @brief: support value being None.
        '''
        skip_process_cache = (
            self.is_mla
            and (infer_mode.is_prefill_only or self.use_fused_mla_qkv)
        )
        '''
        ==================
        End of MLU Hijack
        ==================
        '''
        if (
            self.kv_sharing_target_layer_name is None
            and key is not None
            and value is not None
            and not skip_process_cache
        ):
            # Reshape the input keys and values and store them in the cache.
            # Skip this if sharing KV cache with an earlier attention layer.
            # NOTE(woosuk): Here, key and value are padded while slot_mapping is
            # not padded. However, we don't need to do key[:num_actual_tokens]
            # and value[:num_actual_tokens] because the reshape_and_cache_flash
            # op uses the slot_mapping's shape to determine the number of
            # actual tokens.
            '''
            =============================
            Modify by vllm_mlu
            =============================
            @brief: store key/value cache with mlu ops.
            '''
            if is_quantized_kv_cache(self.kv_cache_dtype):
                mlu_ops.quant_to_paged_cache(
                    k=key[:num_actual_tokens],
                    v=(None if self.is_mla else value[:num_actual_tokens]),
                    k_cache=key_cache,
                    v_cache=value_cache,
                    k_cache_quant_scale=key_cache_scale,
                    v_cache_quant_scale=value_cache_scale,
                    slot_mapping=attn_metadata.slot_mapping.flatten(),
                )
            else:
                mlu_ops.reshape_paged_cache(
                    k=key[:num_actual_tokens],
                    v=(None if self.is_mla else value[:num_actual_tokens]),
                    k_cache=key_cache,
                    v_cache=value_cache,
                    slot_mapping=attn_metadata.slot_mapping.flatten(),
                )
            '''
            ==================
            End of MLU Hijack
            ==================
            '''

        '''
        =============================
        Modify by vllm_mlu
        =============================
        @brief: skip cascade attention on the mlu platform.
        '''
        if attn_metadata.use_cascade:
            raise RuntimeError(
                f"mlu v1 does not support use_cascade={attn_metadata.use_cascade}, "
                f"attn_metadata={attn_metadata}."
            )
        '''
        ==================
        End of MLU Hijack
        ==================
        '''

        cu_seqlens_q = attn_metadata.query_start_loc
        cu_seqlens_kv = attn_metadata.seq_start_loc
        seqused_k = attn_metadata.seq_lens
        max_seqlen_q = attn_metadata.max_query_len
        max_seqlen_k = attn_metadata.max_seq_len
        block_table = attn_metadata.block_table

        alibi_slopes = (
            None if self.alibi_slopes is None
            else self.alibi_slopes.repeat(seqused_k.shape[0], 1)
        )
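        # For MLA, the value head size can differ from the query/key head
        # size, so it is taken from the value tensor itself.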
        head_size_v = value.shape[-1] if self.is_mla else self.head_size
        q_quant_scale = kwargs.get("q_quant_scale", None)

        if infer_mode.is_prefill_only:
            num_prefill_query_tokens = num_actual_tokens
            num_prefill_kv_tokens = num_actual_tokens
            mlu_ops.flash_attention(
                q=query[:num_prefill_query_tokens],
                k=key[:num_prefill_kv_tokens],
                v=value[:num_prefill_kv_tokens],
                out=output[:num_prefill_query_tokens],
                cu_seq_lens_q=cu_seqlens_q,
                cu_seq_lens_kv=cu_seqlens_kv,
                alibi_slope=alibi_slopes,
                attn_bias=None,
                max_seq_len_q=max_seqlen_q,
                max_seq_len_kv=max_seqlen_k,
                softmax_scale=self.scale,
                is_causal=True,
                window_size_left=self.sliding_window[0],
                window_size_right=self.sliding_window[1],
                compute_dtype=attn_metadata.compute_dtype,
                return_lse=False,
            )
        elif infer_mode.is_chunked:
            # prefill & decode mixed
            # NOTE: Splitting prefill chunks and decode tokens gives better
            # performance on MLU devices.
            chunk_fa_metadata = attn_metadata.chunk_fa_metadata
            prefill_ctx = chunk_fa_metadata.prefill_ctx
            decode_ctx = chunk_fa_metadata.decode_ctx
            num_decodes = decode_ctx.batch_size
            num_decode_tokens = decode_ctx.num_actual_tokens
            num_prefills = prefill_ctx.batch_size
            if num_prefills > 0:
                self._forward_prefill_chunk(
                    query=query[num_decode_tokens:],
                    key_cache=key_cache,
                    value_cache=value_cache,
                    output=output[num_decode_tokens:],
                    block_table=block_table[num_decodes:],
                    seqused_k=seqused_k[num_decodes:],
                    compute_dtype=attn_metadata.compute_dtype,
                    prefill_ctx=prefill_ctx,
                    alibi_slopes=alibi_slopes,
                    key_cache_scale=key_cache_scale,
                    value_cache_scale=value_cache_scale,
                )
            if num_decodes > 0:
                if q_quant_scale is not None:
                    q_quant_scale = q_quant_scale[:num_decode_tokens]
                self._forward_decode_only(
                    query=query[:num_decode_tokens],
                    key_cache=key_cache,
                    value_cache=value_cache,
                    output=output[:num_decode_tokens],
                    block_table=block_table[:num_decodes],
                    seqused_k=seqused_k[:num_decodes],
                    max_seqlen_k=decode_ctx.max_seq_len,
                    head_size_v=head_size_v,
                    compute_dtype=attn_metadata.compute_dtype,
                    alibi_slopes=alibi_slopes,
                    key_cache_scale=key_cache_scale,
                    value_cache_scale=value_cache_scale,
                    q_quant_scale=q_quant_scale,
                )
        else:
            # decode only
            if q_quant_scale is not None:
                q_quant_scale = q_quant_scale[:num_actual_tokens]
            self._forward_decode_only(
                query=query[:num_actual_tokens],
                key_cache=key_cache,
                value_cache=value_cache,
                output=output[:num_actual_tokens],
                block_table=block_table,
                seqused_k=seqused_k,
                max_seqlen_k=max_seqlen_k,
                head_size_v=head_size_v,
                compute_dtype=attn_metadata.compute_dtype,
                alibi_slopes=alibi_slopes,
                key_cache_scale=key_cache_scale,
                value_cache_scale=value_cache_scale,
                q_quant_scale=q_quant_scale,
            )
        return output

    def _forward_prefill_chunk(
        self,
        query: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        output: torch.Tensor,
        block_table: torch.Tensor,
        seqused_k: torch.Tensor,
        compute_dtype: torch.dtype,
        prefill_ctx: MLUChunkFlashAttentionMetadata.ChunkContextMetadata,
        alibi_slopes: torch.Tensor | None = None,
        key_cache_scale: torch.Tensor | None = None,
        value_cache_scale: torch.Tensor | None = None,
    ):
        '''
        Compute prefill chunks when chunked prefill is enabled.
        NOTE: If the kv_cache is quantized, it is first dequantized into
        contiguous key and value tensors.
        '''
        if is_quantized_kv_cache(self.kv_cache_dtype):
            total_seqlens = prefill_ctx.total_seqlens
            key_cache_dequant = torch.zeros(
                (total_seqlens, self.num_kv_heads, self.head_size),
                dtype=query.dtype,
                device=key_cache.device
            )
            value_cache_dequant = None
            if value_cache is not None:
                value_cache_dequant = torch.zeros(
                    (total_seqlens, self.num_kv_heads, self.head_size),
                    dtype=query.dtype,
                    device=key_cache.device
                )
            mlu_ops.dequant_from_paged_cache(
                key=key_cache_dequant,
                value=value_cache_dequant,
                key_cache=key_cache,
                value_cache=value_cache,
                key_cache_quant_scale=key_cache_scale,
                value_cache_quant_scale=value_cache_scale,
                context_lengths=seqused_k,
                max_context_len=prefill_ctx.max_seq_len,
                context_seq_offset=None,
                block_tables=block_table,
                quant_mode=1,
                quant_bit=8
            )
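            # The dequantized K/V are already gathered into contiguous
            # tensors, so the flash_attention call below does not need a
            # block table.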
            block_table_dequant = None
        else:
            key_cache_dequant = key_cache
            value_cache_dequant = value_cache
            block_table_dequant = block_table
        mlu_ops.flash_attention(
            q=query,
            k=key_cache_dequant,
            v=value_cache_dequant,
            out=output,
            cu_seq_lens_q=prefill_ctx.cu_seqlens_q,
            cu_seq_lens_kv=prefill_ctx.cu_seqlens_kv,
            alibi_slope=alibi_slopes,
            attn_bias=None,
            max_seq_len_q=prefill_ctx.max_query_len,
            max_seq_len_kv=prefill_ctx.max_seq_len,
            softmax_scale=self.scale,
            is_causal=True,
            window_size_left=self.sliding_window[0],
            window_size_right=self.sliding_window[1],
            compute_dtype=compute_dtype,
            return_lse=False,
            block_tables=block_table_dequant,
        )

    def _forward_decode_only(
        self,
        query: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        output: torch.Tensor,
        block_table: torch.Tensor,
        seqused_k: torch.Tensor,
        max_seqlen_k: int,
        head_size_v: int,
        compute_dtype: torch.dtype,
        alibi_slopes: torch.Tensor | None = None,
        key_cache_scale: torch.Tensor | None = None,
        value_cache_scale: torch.Tensor | None = None,
        q_quant_scale: torch.Tensor | None = None,
    ):
        '''
        Compute decode tokens only.
        NOTE: The query only supports padded mode; be careful when using an
        MTP model.
        '''
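        # Reshape to [batch, query_tokens_per_request, heads, head_size];
        # every decode request is assumed to contribute the same (padded)
        # number of query tokens.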
        batch_size = block_table.shape[0]
        decode_query = query.view(batch_size, -1, self.num_heads, self.head_size)
        decode_output = output.view(batch_size, -1, self.num_heads, head_size_v)
        if q_quant_scale is not None:
            q_quant_scale = q_quant_scale.view(batch_size, -1, self.num_heads)
        mlu_ops.single_query_cached_kv_attn(
            q=decode_query,
            k_cache=key_cache,
            v_cache=value_cache,
            out=decode_output,
            block_tables=block_table,
            context_lens=seqused_k,
            k_cache_quant_scale=key_cache_scale,
            v_cache_quant_scale=value_cache_scale,
            alibi_slopes=alibi_slopes,
            max_contxt_len=max_seqlen_k,
            windows_size_left=self.sliding_window[0],
            windows_size_right=self.sliding_window[1],
            softmax_scale=self.scale,
            head_size_v=(-1 if not self.is_mla else head_size_v),
            compute_dtype=compute_dtype,
            q_quant_scale=q_quant_scale,
            decoder_attn_dtype=self.decoder_attn_dtype,
        )

    def _forward_encoder_attention(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        output: torch.Tensor,
        attn_metadata: MLUFlashAttentionMetadata,
        layer: torch.nn.Module,
    ) -> torch.Tensor:
        """Forward pass for encoder attention without KV cache.

        Args:
            query: shape = [num_encoder_tokens, num_heads, head_size]
            key: shape = [num_encoder_tokens, num_kv_heads, head_size]
            value: shape = [num_encoder_tokens, num_kv_heads, head_size]
            output: shape = [num_encoder_tokens, num_heads, head_size]
            attn_metadata: Encoder attention metadata
            layer: The attention layer
        """
        # For encoder attention, process FP8 quantization if needed
        if self.kv_cache_dtype.startswith("fp8"):
            raise NotImplementedError(
                "quantization is not supported for encoder attention"
            )

        # Use encoder-specific metadata for sequence information
        cu_seqlens_q = attn_metadata.query_start_loc
        cu_seqlens_k = attn_metadata.query_start_loc
        max_seqlen_q = attn_metadata.max_query_len
        max_seqlen_k = attn_metadata.max_query_len

        # Call flash attention directly on Q, K, V tensors
        mlu_ops.flash_attention(
            q=query,
            k=key,
            v=value,
            out=output,
            cu_seq_lens_q=cu_seqlens_q,
            cu_seq_lens_kv=cu_seqlens_k,
            alibi_slope=None,
            attn_bias=None,
            max_seq_len_q=max_seqlen_q,
            max_seq_len_kv=max_seqlen_k,
            softmax_scale=self.scale,
            is_causal=False,  # Encoder attention is bidirectional
            window_size_left=self.sliding_window[0],
            window_size_right=self.sliding_window[1],
            compute_dtype=attn_metadata.compute_dtype,
            return_lse=False,
        )

        return output