forked from EngineX-Hygon/enginex-hygon-vllm
init src 0.9.2
0
vllm/v1/attention/__init__.py
Normal file
0
vllm/v1/attention/backends/__init__.py
Normal file
184
vllm/v1/attention/backends/cpu_attn.py
Normal file
@@ -0,0 +1,184 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import numpy as np
import torch

from vllm.attention.backends.abstract import (AttentionBackend,
                                              AttentionMetadata)
from vllm.attention.backends.torch_sdpa import (TorchSDPABackendImpl,
                                                TorchSDPAMetadata)
from vllm.attention.backends.utils import CommonAttentionState
from vllm.attention.ops.ipex_attn import PagedAttention
from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
                                              CommonAttentionMetadata)
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.v1.worker.block_table import BlockTable
from vllm.v1.worker.cpu_model_runner import CPUModelRunner
from vllm.v1.worker.gpu_input_batch import InputBatch


class TorchSDPABackend(AttentionBackend):
    accept_output_buffer: bool = False

    @classmethod
    def get_supported_head_sizes(cls) -> list[int]:
        return PagedAttention.get_supported_head_sizes()

    @classmethod
    def validate_head_size(cls, head_size: int) -> None:
        supported_head_sizes = cls.get_supported_head_sizes()
        if head_size not in supported_head_sizes:
            attn_type = cls.__name__.removesuffix("Backend")
            raise ValueError(
                f"Head size {head_size} is not supported by {attn_type}. "
                f"Supported head sizes are: {supported_head_sizes}. "
                "Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use "
                "FlexAttention backend which supports all head sizes.")

    @staticmethod
    def get_name() -> str:
        return "TORCH_SDPA_VLLM_V1"

    @staticmethod
    def get_impl_cls() -> type["TorchSDPABackendImpl"]:
        return TorchSDPABackendImpl

    @staticmethod
    def get_metadata_cls() -> type["AttentionMetadata"]:
        return TorchSDPAMetadata

    @staticmethod
    def get_state_cls() -> type["CommonAttentionState"]:
        return CommonAttentionState

    @staticmethod
    def get_builder_cls() -> type["TorchSDPAMetadataBuilderV1"]:
        return TorchSDPAMetadataBuilderV1

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> tuple[int, ...]:
        return PagedAttention.get_kv_cache_shape(num_blocks, block_size,
                                                 num_kv_heads, head_size)

    @staticmethod
    def use_cascade_attention(*args, **kwargs) -> bool:
        return False


class TorchSDPAMetadataBuilderV1(AttentionMetadataBuilder[TorchSDPAMetadata]):

    def __init__(self, runner: CPUModelRunner, kv_cache_spec: AttentionSpec,
                 block_table: BlockTable) -> None:
        self.runner = runner
        self.block_table = block_table

        # For reorder
        self.reorder_prompt_req_index_list = np.empty(self.runner.max_num_reqs,
                                                      dtype=np.int64)
        self.reorder_decode_req_index_list = np.empty(self.runner.max_num_reqs,
                                                      dtype=np.int64)
        self.num_prompt_req: int = 0

        self.seq_start_loc_cpu = torch.zeros(
            runner.max_num_reqs + 1,
            dtype=torch.int32,
            device="cpu",
        )
        self.seq_start_loc_np = self.seq_start_loc_cpu.numpy()

    def reorder_batch(self, input_batch: InputBatch,
                      scheduler_output: SchedulerOutput) -> bool:
        prompt_list_idx = 0
        decode_list_idx = 0
        for req_index in range(input_batch.num_reqs):
            if input_batch.num_computed_tokens_cpu[
                    req_index] < input_batch.num_prompt_tokens[req_index]:
                # prompt stage
                self.reorder_prompt_req_index_list[prompt_list_idx] = req_index
                prompt_list_idx += 1
            else:
                # decode stage
                self.reorder_decode_req_index_list[decode_list_idx] = req_index
                decode_list_idx += 1
        assert decode_list_idx + prompt_list_idx == input_batch.num_reqs

        # Update prompt requests number
        self.num_prompt_req = prompt_list_idx

        reorder_req_num = 0
        for req_index in range(decode_list_idx):
            if self.reorder_decode_req_index_list[req_index] < prompt_list_idx:
                reorder_req_num += 1
            else:
                break

        if reorder_req_num == 0:
            return False

        reorder_prompt_list = (
            self.reorder_prompt_req_index_list[:prompt_list_idx]
            [-reorder_req_num:])
        reorder_decode_list = (
            self.reorder_decode_req_index_list[:decode_list_idx]
            [:reorder_req_num])
        assert reorder_decode_list.size == reorder_prompt_list.size

        for idx in range(reorder_req_num):
            prompt_req_index = reorder_prompt_list[idx].item()
            decode_req_index = reorder_decode_list[idx].item()
            input_batch.swap_states(prompt_req_index, decode_req_index)

        return True

    def build(self, common_prefix_len: int,
              common_attn_metadata: CommonAttentionMetadata):
        num_reqs = common_attn_metadata.num_reqs
        num_actual_tokens = common_attn_metadata.num_actual_tokens
        max_query_len = common_attn_metadata.max_query_len

        runner = self.runner
        block_table = self.block_table
        seq_lens_np = runner.seq_lens_np[:num_reqs]
        num_prompt_req = self.num_prompt_req
        max_prefill_seq_len = seq_lens_np[:num_prompt_req].max().item(
        ) if num_prompt_req > 0 else 0
        max_decode_seq_len = seq_lens_np[num_prompt_req:num_reqs].max().item(
        ) if num_prompt_req < num_reqs else 0
        self.seq_start_loc_np[0] = 0
        np.cumsum(seq_lens_np, out=self.seq_start_loc_np[1:num_reqs + 1])
        num_prefill_tokens = runner.query_start_loc_np[num_prompt_req].item()
        num_decode_tokens = runner.query_start_loc_np[num_reqs].item(
        ) - num_prefill_tokens
        slot_mapping = block_table.slot_mapping_cpu[:num_actual_tokens].long()
        block_table_tensor = block_table.get_device_tensor()
        attn_metadata = TorchSDPAMetadata(
            num_prefills=num_prompt_req,
            num_prefill_tokens=num_prefill_tokens,
            num_decode_tokens=num_decode_tokens,
            slot_mapping=slot_mapping,
            seq_lens_tensor=runner.
            seq_lens_cpu[num_prompt_req:num_reqs],  # decode
            max_decode_seq_len=max_decode_seq_len,  # decode
            block_tables=block_table_tensor[num_prompt_req:num_reqs],  # decode
            chunked_prefill=True,
            max_query_len=max_query_len,
            max_kv_len=max_prefill_seq_len,
            prefill_query_start_loc=runner.
            query_start_loc_cpu[:num_prompt_req + 1],  # prefill
            kv_start_loc=self.seq_start_loc_cpu[:num_prompt_req +
                                                1],  # prefill
            prefill_block_tables=block_table_tensor[:
                                                    num_prompt_req],  # prefill
            query_start_loc=runner.query_start_loc_cpu[:num_reqs +
                                                       1],  # for logits index
            multi_modal_placeholder_index_maps=None,
            enable_kv_scales_calculation=False,
        )

        return attn_metadata
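The reorder_batch method above moves prompt-stage requests to the front of the persistent batch and decode-stage requests behind them, swapping only the misplaced pairs. A minimal standalone sketch of that swap logic, assuming a toy batch represented by a list of is-prompt flags (the real code swaps full per-request state via InputBatch.swap_states):

def reorder(is_prompt: list[bool]) -> list[int]:
    """Return the request order after moving prompt requests to the front."""
    order = list(range(len(is_prompt)))
    prompt_idx = [i for i, p in enumerate(is_prompt) if p]
    decode_idx = [i for i, p in enumerate(is_prompt) if not p]
    num_prompts = len(prompt_idx)
    # Decode requests sitting inside the first `num_prompts` slots must be
    # swapped with prompt requests sitting past that boundary.
    misplaced_decodes = [i for i in decode_idx if i < num_prompts]
    misplaced_prompts = [i for i in prompt_idx if i >= num_prompts]
    for d, p in zip(misplaced_decodes, misplaced_prompts):
        order[d], order[p] = order[p], order[d]
    return order

# Requests 1 and 3 are decodes; after reordering, the first three slots hold
# prompt requests 0, 4 and 2.
assert reorder([True, False, True, False, True]) == [0, 4, 2, 3, 1]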
935
vllm/v1/attention/backends/flash_attn.py
Normal file
@@ -0,0 +1,935 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Attention layer with FlashAttention."""
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any, ClassVar, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
||||
AttentionMetadata, AttentionType,
|
||||
is_quantized_kv_cache)
|
||||
from vllm.attention.layer import Attention
|
||||
from vllm.attention.ops.merge_attn_states import merge_attn_states
|
||||
from vllm.attention.utils.fa_utils import (flash_attn_supports_fp8,
|
||||
get_flash_attn_version,
|
||||
is_flash_attn_varlen_func_available)
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
if is_flash_attn_varlen_func_available():
|
||||
if not current_platform.is_rocm():
|
||||
from vllm.attention.utils.fa_utils import (flash_attn_varlen_func,
|
||||
get_scheduler_metadata,
|
||||
reshape_and_cache_flash)
|
||||
else:
|
||||
from vllm.attention.utils.fa_utils import (vllm_flash_attn_varlen_func,
|
||||
reshape_and_cache_cuda)
|
||||
|
||||
from vllm.config import VllmConfig, get_layers_from_vllm_config
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import cdiv
|
||||
from vllm.v1.attention.backends.utils import (
|
||||
AttentionMetadataBuilder, CommonAttentionMetadata, get_kv_cache_layout,
|
||||
make_local_attention_virtual_batches)
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||
from vllm.v1.worker.block_table import BlockTable
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
# NOTE(woosuk): This is an arbitrary number. Tune it if needed.
|
||||
_DEFAULT_MAX_NUM_SPLITS_FOR_CUDA_GRAPH = 16
|
||||
|
||||
|
||||
class FlashAttentionBackend(AttentionBackend):
|
||||
|
||||
accept_output_buffer: bool = True
|
||||
|
||||
@classmethod
|
||||
def get_supported_head_sizes(cls) -> list[int]:
|
||||
return [32, 64, 96, 128, 160, 192, 224, 256]
|
||||
|
||||
@classmethod
|
||||
def validate_head_size(cls, head_size: int) -> None:
|
||||
supported_head_sizes = cls.get_supported_head_sizes()
|
||||
if head_size not in supported_head_sizes:
|
||||
attn_type = cls.__name__.removesuffix("Backend")
|
||||
raise ValueError(
|
||||
f"Head size {head_size} is not supported by {attn_type}. "
|
||||
f"Supported head sizes are: {supported_head_sizes}. "
|
||||
"Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use "
|
||||
"FlexAttention backend which supports all head sizes.")
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "FLASH_ATTN_VLLM_V1"
|
||||
|
||||
@staticmethod
|
||||
def get_impl_cls() -> type["FlashAttentionImpl"]:
|
||||
return FlashAttentionImpl
|
||||
|
||||
@staticmethod
|
||||
def get_metadata_cls() -> type["AttentionMetadata"]:
|
||||
return FlashAttentionMetadata
|
||||
|
||||
@staticmethod
|
||||
def get_builder_cls() -> type["FlashAttentionMetadataBuilder"]:
|
||||
return FlashAttentionMetadataBuilder
|
||||
|
||||
if not current_platform.is_rocm():
|
||||
@staticmethod
|
||||
def get_kv_cache_shape(
|
||||
num_blocks: int,
|
||||
block_size: int,
|
||||
num_kv_heads: int,
|
||||
head_size: int,
|
||||
) -> tuple[int, ...]:
|
||||
if block_size % 16 != 0:
|
||||
raise ValueError("Block size must be a multiple of 16.")
|
||||
return (2, num_blocks, block_size, num_kv_heads, head_size)
|
||||
|
||||
@staticmethod
|
||||
def get_kv_cache_stride_order() -> tuple[int, ...]:
|
||||
# `stride_order` indicates the permutation that gets
|
||||
# us from `get_kv_cache_shape` to the actual memory layout we want.
|
||||
cache_layout = get_kv_cache_layout()
|
||||
if cache_layout == "NHD":
|
||||
stride_order = (0, 1, 2, 3, 4)
|
||||
elif cache_layout == "HND":
|
||||
stride_order = (0, 1, 3, 2, 4)
|
||||
else:
|
||||
raise ValueError(f"Unknown cache layout format {cache_layout}.")
|
||||
return stride_order
|
||||
else:
|
||||
@staticmethod
|
||||
def get_kv_cache_shape(
|
||||
num_blocks: int,
|
||||
block_size: int,
|
||||
num_kv_heads: int,
|
||||
head_size: int,
|
||||
) -> tuple[tuple[int, ...], tuple[int, ...]]:
|
||||
if block_size % 16 != 0:
|
||||
raise ValueError("Block size must be a multiple of 16.")
|
||||
return (
|
||||
(num_blocks, num_kv_heads, block_size, head_size),
|
||||
(num_blocks, num_kv_heads, head_size, block_size),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_kv_cache_stride_order() -> tuple[tuple[int, ...], tuple[int, ...]]:
|
||||
# `stride_order` indicates the permutation that gets
|
||||
# us from `get_kv_cache_shape` to the actual memory layout we want.
|
||||
cache_layout = get_kv_cache_layout()
|
||||
if cache_layout == "NHD":
|
||||
key_stride_order = (0, 1, 2, 3)
|
||||
value_stride_order = (0, 1, 2, 3)
|
||||
elif cache_layout == "HND":
|
||||
key_stride_order = (0, 2, 1, 3)
|
||||
value_stride_order = (0, 2, 1, 3)
|
||||
else:
|
||||
raise ValueError(f"Unknown cache layout format {cache_layout}.")
|
||||
return key_stride_order, value_stride_order
|
||||
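The get_kv_cache_stride_order methods above describe the permutation that maps the logical get_kv_cache_shape to the physical memory layout ("NHD" vs "HND"). A minimal sketch of how such a stride order can be applied, with small illustrative dimensions; the real allocator applies the permutation when creating the cache buffers:

import torch

logical_shape = (2, 4, 16, 8, 64)  # (2, num_blocks, block_size, num_kv_heads, head_size)
stride_order = (0, 1, 3, 2, 4)     # "HND": heads laid out before block_size in memory

# The physical buffer is allocated with the dimensions permuted by stride_order...
physical = torch.empty(tuple(logical_shape[i] for i in stride_order))

# ...and a view in the logical order is recovered with the inverse permutation
# (position j of the physical layout came from logical dimension stride_order[j]).
inverse = [0] * len(stride_order)
for j, dim in enumerate(stride_order):
    inverse[dim] = j
logical_view = physical.permute(*inverse)

assert logical_view.shape == torch.Size(logical_shape)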
|
||||
|
||||
@dataclass
|
||||
class FlashAttentionMetadata:
|
||||
# NOTE(sang): Definition of context_len, query_len, and seq_len.
|
||||
# |---------- N-1 iteration --------|
|
||||
# |---------------- N iteration ---------------------|
|
||||
# |- tokenA -|......................|-- newTokens ---|
|
||||
# |---------- context_len ----------|
|
||||
# |-------------------- seq_len ---------------------|
|
||||
# |-- query_len ---|
|
||||
|
||||
num_actual_tokens: int # Number of tokens excluding padding.
|
||||
max_query_len: int
|
||||
query_start_loc: torch.Tensor
|
||||
max_seq_len: int
|
||||
seq_lens: torch.Tensor
|
||||
block_table: torch.Tensor
|
||||
slot_mapping: torch.Tensor
|
||||
|
||||
# For cascade attention.
|
||||
use_cascade: bool
|
||||
common_prefix_len: int
|
||||
cu_prefix_query_lens: Optional[torch.Tensor]
|
||||
prefix_kv_lens: Optional[torch.Tensor]
|
||||
suffix_kv_lens: Optional[torch.Tensor]
|
||||
|
||||
# Optional aot scheduling
|
||||
scheduler_metadata: Optional[torch.Tensor] = None
|
||||
prefix_scheduler_metadata: Optional[torch.Tensor] = None
|
||||
max_num_splits: int = 0
|
||||
|
||||
# for local attention
|
||||
@dataclass
|
||||
class LocalAttentionMetadata:
|
||||
local_query_start_loc: torch.Tensor
|
||||
local_seqused_k: torch.Tensor
|
||||
local_block_table: torch.Tensor
|
||||
local_max_query_len: int
|
||||
local_max_seq_len: int
|
||||
local_scheduler_metadata: Optional[torch.Tensor]
|
||||
|
||||
local_attn_metadata: Optional[LocalAttentionMetadata] = None
|
||||
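A small numeric illustration of the context_len / query_len / seq_len relationship and of the cumulative query_start_loc documented in the dataclass above, using made-up values:

# One request that has already processed 12 tokens and is scheduled with 4
# new tokens in this step:
context_len = 12                    # tokens already in the KV cache
query_len = 4                       # newly scheduled tokens
seq_len = context_len + query_len   # 16 tokens attended to in total
assert seq_len == 16

# For a batch, query_start_loc is the cumulative sum of per-request query
# lengths; query lengths [4, 1, 1] give [0, 4, 5, 6] and num_actual_tokens == 6.
query_lens = [4, 1, 1]
query_start_loc = [0]
for q in query_lens:
    query_start_loc.append(query_start_loc[-1] + q)
assert query_start_loc == [0, 4, 5, 6]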
|
||||
|
||||
def _get_sliding_window_configs(
|
||||
vllm_config: VllmConfig) -> set[Optional[tuple[int, int]]]:
|
||||
"""Get the set of all sliding window configs used in the model."""
|
||||
sliding_window_configs: set[Optional[tuple[int, int]]] = set()
|
||||
layers = get_layers_from_vllm_config(vllm_config, Attention)
|
||||
for layer in layers.values():
|
||||
assert isinstance(layer.impl, FlashAttentionImpl)
|
||||
sliding_window_configs.add(layer.impl.sliding_window)
|
||||
return sliding_window_configs
|
||||
|
||||
|
||||
class FlashAttentionMetadataBuilder(
|
||||
AttentionMetadataBuilder[FlashAttentionMetadata]):
|
||||
full_cudagraph_supported: ClassVar[bool] = (get_flash_attn_version() == 3
                                            or current_platform.is_rocm())
|
||||
|
||||
def __init__(self, runner: "GPUModelRunner", kv_cache_spec: AttentionSpec,
|
||||
block_table: BlockTable):
|
||||
model_config = runner.model_config
|
||||
compilation_config = runner.vllm_config.compilation_config
|
||||
|
||||
self.runner = runner
|
||||
self.num_heads_q = model_config.get_num_attention_heads(
|
||||
runner.parallel_config)
|
||||
self.num_heads_kv = model_config.get_num_kv_heads(
|
||||
runner.parallel_config)
|
||||
self.headdim = model_config.get_head_size()
|
||||
self.block_size = kv_cache_spec.block_size
|
||||
self.kv_cache_spec = kv_cache_spec
|
||||
self.block_table = block_table
|
||||
|
||||
self.max_num_splits = 0 # No upper bound on the number of splits.
|
||||
self.aot_schedule = (get_flash_attn_version() == 3)
|
||||
self.use_full_cuda_graph = compilation_config.full_cuda_graph
|
||||
if self.use_full_cuda_graph:
|
||||
if not current_platform.is_rocm():
|
||||
if not self.aot_schedule:
|
||||
raise ValueError(
|
||||
"AoT scheduling is required for full cuda graph.")
|
||||
capture_sizes = compilation_config.cudagraph_capture_sizes
|
||||
if not capture_sizes:
|
||||
raise ValueError(
|
||||
"cudagraph_capture_sizes should not be None when "
|
||||
"full_cuda_graph is True.")
|
||||
self.max_cudagraph_size = max(capture_sizes)
|
||||
if self.max_cudagraph_size > 992:
|
||||
# This condition derives from FA3's internal heuristic.
|
||||
# TODO(woosuk): Support larger cudagraph sizes.
|
||||
raise ValueError(
|
||||
"Capture size larger than 992 is not supported for "
|
||||
"full cuda graph.")
|
||||
|
||||
self.scheduler_metadata = torch.zeros(
|
||||
self.runner.max_num_reqs + 1,
|
||||
dtype=torch.int32,
|
||||
device=self.runner.device,
|
||||
)
|
||||
# When using cuda graph, we need to set the upper bound of the
|
||||
# number of splits so that large enough intermediate buffers are
|
||||
# pre-allocated during capture.
|
||||
self.max_num_splits = _DEFAULT_MAX_NUM_SPLITS_FOR_CUDA_GRAPH
|
||||
|
||||
# Sliding window size to be used with the AOT scheduler will be
|
||||
# populated on first build() call.
|
||||
self.aot_sliding_window: Optional[tuple[int, int]] = None
|
||||
|
||||
def build(
|
||||
self, common_prefix_len: int,
|
||||
common_attn_metadata: CommonAttentionMetadata
|
||||
) -> FlashAttentionMetadata:
|
||||
num_reqs = common_attn_metadata.num_reqs
|
||||
num_actual_tokens = common_attn_metadata.num_actual_tokens
|
||||
max_query_len = common_attn_metadata.max_query_len
|
||||
|
||||
max_seq_len = int(self.runner.seq_lens_np[:num_reqs].max())
|
||||
query_start_loc = common_attn_metadata.query_start_loc
|
||||
seq_lens = common_attn_metadata.seq_lens
|
||||
block_table = self.block_table
|
||||
block_table_tensor = block_table.get_device_tensor()[:num_reqs]
|
||||
|
||||
block_table.slot_mapping[:num_actual_tokens].copy_(
|
||||
block_table.slot_mapping_cpu[:num_actual_tokens],
|
||||
non_blocking=True)
|
||||
# Fill unused with -1. Needed for reshape_and_cache in full cuda graph
|
||||
# mode.
|
||||
block_table.slot_mapping[num_actual_tokens:].fill_(-1)
|
||||
|
||||
slot_mapping = block_table.slot_mapping[:num_actual_tokens]
|
||||
|
||||
if self.aot_sliding_window is None:
|
||||
self.aot_sliding_window = (-1, -1)
|
||||
# For the AOT scheduler, the sliding window value must be the same for
# all layers. It is populated on the first build() call, once the layers
# have been constructed (so it cannot be set in __init__).
|
||||
if self.aot_schedule:
|
||||
sliding_window_configs = _get_sliding_window_configs(
|
||||
self.runner.vllm_config)
|
||||
if len(sliding_window_configs) == 1:
|
||||
sliding_window_config = sliding_window_configs.pop()
|
||||
if sliding_window_config is not None:
|
||||
self.aot_sliding_window = sliding_window_config
|
||||
elif len(sliding_window_configs) > 1:
|
||||
self.aot_schedule = False
|
||||
|
||||
def schedule(batch_size, cu_query_lens, max_query_len, seqlens,
|
||||
max_seq_len, causal):
|
||||
if self.aot_schedule:
|
||||
return get_scheduler_metadata(
|
||||
batch_size=batch_size,
|
||||
max_seqlen_q=max_query_len,
|
||||
max_seqlen_k=max_seq_len,
|
||||
cache_seqlens=seqlens,
|
||||
num_heads_q=self.num_heads_q,
|
||||
num_heads_kv=self.num_heads_kv,
|
||||
headdim=self.headdim,
|
||||
page_size=self.block_size,
|
||||
cu_seqlens_q=cu_query_lens,
|
||||
causal=causal,
|
||||
window_size=self.aot_sliding_window,
|
||||
num_splits=self.max_num_splits,
|
||||
)
|
||||
return None
|
||||
|
||||
# for local attention
|
||||
local_attn_metadata = None
|
||||
if self.runner.attention_chunk_size is not None:
|
||||
seqlens_q_local_np, virt_q_cu_seqlens_np, virt_k_seqlens_np, \
|
||||
virt_block_table_tensor = make_local_attention_virtual_batches(
|
||||
self.runner.attention_chunk_size,
|
||||
self.runner.query_start_loc_np[:num_reqs + 1],
|
||||
self.runner.seq_lens_np[:num_reqs],
|
||||
block_table_tensor,
|
||||
self.block_size,
|
||||
)
|
||||
local_query_start_loc = torch.from_numpy(virt_q_cu_seqlens_np).to(
|
||||
self.runner.device, non_blocking=True)
|
||||
local_seqused_k = torch.from_numpy(virt_k_seqlens_np).to(
|
||||
self.runner.device, non_blocking=True)
|
||||
local_max_query_len = seqlens_q_local_np.max()
|
||||
local_max_seq_len = virt_k_seqlens_np.max()
|
||||
local_scheduler_metadata = schedule(
|
||||
batch_size=local_query_start_loc.shape[0] - 1,
|
||||
cu_query_lens=local_query_start_loc,
|
||||
max_query_len=local_max_query_len,
|
||||
seqlens=local_seqused_k,
|
||||
max_seq_len=local_max_seq_len,
|
||||
causal=True)
|
||||
|
||||
local_attn_metadata = FlashAttentionMetadata.LocalAttentionMetadata(
|
||||
local_query_start_loc=local_query_start_loc,
|
||||
local_seqused_k=local_seqused_k,
|
||||
local_block_table=virt_block_table_tensor,
|
||||
local_max_query_len=local_max_query_len,
|
||||
local_max_seq_len=local_max_seq_len,
|
||||
local_scheduler_metadata=local_scheduler_metadata,
|
||||
)
|
||||
|
||||
use_cascade = common_prefix_len > 0
|
||||
|
||||
if use_cascade:
|
||||
cu_prefix_query_lens = torch.tensor([0, num_actual_tokens],
|
||||
dtype=torch.int32,
|
||||
device=self.runner.device)
|
||||
prefix_kv_lens = torch.tensor([common_prefix_len],
|
||||
dtype=torch.int32,
|
||||
device=self.runner.device)
|
||||
suffix_kv_lens = (self.runner.seq_lens_np[:num_reqs] -
|
||||
common_prefix_len)
|
||||
suffix_kv_lens = torch.from_numpy(suffix_kv_lens).to(
|
||||
self.runner.device)
|
||||
prefix_scheduler_metadata = schedule(
|
||||
batch_size=1,
|
||||
cu_query_lens=cu_prefix_query_lens,
|
||||
max_query_len=num_actual_tokens,
|
||||
seqlens=prefix_kv_lens,
|
||||
max_seq_len=common_prefix_len,
|
||||
causal=False)
|
||||
scheduler_metadata = schedule(batch_size=num_reqs,
|
||||
cu_query_lens=query_start_loc,
|
||||
max_query_len=max_query_len,
|
||||
seqlens=suffix_kv_lens,
|
||||
max_seq_len=max_seq_len -
|
||||
common_prefix_len,
|
||||
causal=True)
|
||||
else:
|
||||
cu_prefix_query_lens = None
|
||||
prefix_kv_lens = None
|
||||
suffix_kv_lens = None
|
||||
prefix_scheduler_metadata = None
|
||||
scheduler_metadata = schedule(batch_size=num_reqs,
|
||||
cu_query_lens=query_start_loc,
|
||||
max_query_len=max_query_len,
|
||||
seqlens=seq_lens,
|
||||
max_seq_len=max_seq_len,
|
||||
causal=True)
|
||||
|
||||
if not current_platform.is_rocm() and self.use_full_cuda_graph:
|
||||
assert scheduler_metadata is not None
|
||||
n = scheduler_metadata.shape[0]
|
||||
self.scheduler_metadata[:n] = scheduler_metadata
|
||||
# NOTE(woosuk): We should zero out the rest of the scheduler
|
||||
# metadata to guarantee the correctness. Otherwise, some thread
|
||||
# blocks may use the invalid scheduler metadata and overwrite the
|
||||
# output buffer.
|
||||
self.scheduler_metadata[n:] = 0
|
||||
scheduler_metadata = self.scheduler_metadata[:n]
|
||||
|
||||
max_num_splits = 0
|
||||
if (self.use_full_cuda_graph
|
||||
and num_actual_tokens <= self.max_cudagraph_size):
|
||||
# NOTE(woosuk): Setting num_splits > 1 may increase the memory
|
||||
# usage, because the intermediate buffers of size [num_splits,
|
||||
# num_heads, num_tokens, head_size] are allocated. Therefore,
|
||||
# we only set num_splits when using cuda graphs.
|
||||
max_num_splits = self.max_num_splits
|
||||
|
||||
attn_metadata = FlashAttentionMetadata(
|
||||
num_actual_tokens=num_actual_tokens,
|
||||
max_query_len=max_query_len,
|
||||
query_start_loc=query_start_loc,
|
||||
max_seq_len=max_seq_len,
|
||||
seq_lens=seq_lens,
|
||||
block_table=block_table_tensor,
|
||||
slot_mapping=slot_mapping,
|
||||
use_cascade=use_cascade,
|
||||
common_prefix_len=common_prefix_len,
|
||||
scheduler_metadata=scheduler_metadata,
|
||||
cu_prefix_query_lens=cu_prefix_query_lens,
|
||||
prefix_kv_lens=prefix_kv_lens,
|
||||
suffix_kv_lens=suffix_kv_lens,
|
||||
local_attn_metadata=local_attn_metadata,
|
||||
prefix_scheduler_metadata=prefix_scheduler_metadata,
|
||||
max_num_splits=max_num_splits,
|
||||
)
|
||||
return attn_metadata
|
||||
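In full CUDA graph mode, build() above copies the freshly computed scheduler metadata into the persistent self.scheduler_metadata buffer and zeroes the tail so captured kernels never read stale entries. A minimal sketch of that staging pattern, with an assumed buffer size:

import torch

# Persistent buffer sized for the largest captured batch (size is an
# illustrative assumption).
persistent_metadata = torch.zeros(64, dtype=torch.int32)

def stage_scheduler_metadata(metadata: torch.Tensor) -> torch.Tensor:
    n = metadata.shape[0]
    persistent_metadata[:n] = metadata
    # Zero the tail so thread blocks replayed by the captured graph never see
    # stale entries from a previous, larger batch.
    persistent_metadata[n:] = 0
    return persistent_metadata[:n]

staged = stage_scheduler_metadata(torch.arange(10, dtype=torch.int32))
assert staged.shape[0] == 10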
|
||||
def can_run_in_cudagraph(
|
||||
self, common_attn_metadata: CommonAttentionMetadata) -> bool:
|
||||
# Full CUDA Graph always supported (FA2 support checked separately)
|
||||
return True
|
||||
|
||||
def use_cascade_attention(self, *args, **kwargs) -> bool:
|
||||
return use_cascade_attention(*args, **kwargs)
|
||||
|
||||
|
||||
class FlashAttentionImpl(AttentionImpl):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
scale: float,
|
||||
num_kv_heads: int,
|
||||
alibi_slopes: Optional[list[float]],
|
||||
sliding_window: Optional[int],
|
||||
kv_cache_dtype: str,
|
||||
blocksparse_params: Optional[dict[str, Any]] = None,
|
||||
logits_soft_cap: Optional[float] = None,
|
||||
attn_type: AttentionType = AttentionType.DECODER,
|
||||
kv_sharing_target_layer_name: Optional[str] = None,
|
||||
use_irope: bool = False,
|
||||
) -> None:
|
||||
if blocksparse_params is not None:
|
||||
raise ValueError(
|
||||
"FlashAttention does not support block-sparse attention.")
|
||||
self.num_heads = num_heads
|
||||
self.head_size = head_size
|
||||
self.scale = float(scale)
|
||||
self.num_kv_heads = num_kv_heads
|
||||
if alibi_slopes is not None:
|
||||
alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
|
||||
self.alibi_slopes = alibi_slopes
|
||||
if sliding_window is None:
|
||||
self.sliding_window = (-1, -1)
|
||||
else:
|
||||
self.sliding_window = (sliding_window - 1, 0)
|
||||
self.kv_cache_dtype = kv_cache_dtype
|
||||
if logits_soft_cap is None:
|
||||
# In flash-attn, setting logits_soft_cap as 0 means no soft cap.
|
||||
logits_soft_cap = 0
|
||||
self.logits_soft_cap = logits_soft_cap
|
||||
self.kv_sharing_target_layer_name = kv_sharing_target_layer_name
|
||||
|
||||
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
||||
|
||||
FlashAttentionBackend.validate_head_size(head_size)
|
||||
|
||||
if attn_type != AttentionType.DECODER:
|
||||
raise NotImplementedError("Encoder self-attention and "
|
||||
"encoder/decoder cross-attention "
|
||||
"are not implemented for "
|
||||
"FlashAttentionImpl")
|
||||
self.use_irope = use_irope
|
||||
self.vllm_flash_attn_version = get_flash_attn_version()
|
||||
if is_quantized_kv_cache(self.kv_cache_dtype) \
|
||||
and not flash_attn_supports_fp8():
|
||||
raise NotImplementedError(
|
||||
"FlashAttention does not support fp8 kv-cache on this device.")
|
||||
|
||||
def forward(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: FlashAttentionMetadata,
|
||||
output: Optional[torch.Tensor] = None,
|
||||
output_scale: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
"""Forward pass with FlashAttention.
|
||||
|
||||
Args:
|
||||
query: shape = [num_tokens, num_heads, head_size]
|
||||
key: shape = [num_tokens, num_kv_heads, head_size]
|
||||
value: shape = [num_tokens, num_kv_heads, head_size]
|
||||
kv_cache = [2, num_blocks, block_size, num_kv_heads, head_size]
|
||||
attn_metadata: Metadata for attention.
|
||||
Returns:
|
||||
shape = [num_tokens, num_heads * head_size]
|
||||
NOTE: For FP8 quantization, flash-attn expects the size of
{q,k,v}_descale to be (num_sequences, num_kv_heads).
We use torch's .expand() to avoid duplicating values.
|
||||
"""
|
||||
assert output is not None, "Output tensor must be provided."
|
||||
|
||||
if output_scale is not None:
|
||||
raise NotImplementedError(
|
||||
"fused output quantization is not yet supported"
|
||||
" for FlashAttentionImpl")
|
||||
|
||||
if attn_metadata is None:
|
||||
# Profiling run.
|
||||
return output
|
||||
|
||||
# IMPORTANT!
|
||||
# NOTE(woosuk): With piece-wise CUDA graphs, this method is executed in
|
||||
# eager-mode PyTorch. Thus, we need to be careful about any CPU overhead
|
||||
# in this method. For example, `view` and `slice` (or `[:n]`) operations
|
||||
# are surprisingly slow even in the case they do not invoke any GPU ops.
|
||||
# Minimize the PyTorch ops in this method as much as possible.
|
||||
# Whenever making a change in this method, please benchmark the
|
||||
# performance to make sure it does not introduce any overhead.
|
||||
|
||||
num_actual_tokens = attn_metadata.num_actual_tokens
|
||||
if not current_platform.is_rocm():
|
||||
key_cache, value_cache = kv_cache.unbind(0)
|
||||
else:
|
||||
key_cache, value_cache = kv_cache
|
||||
|
||||
if self.kv_sharing_target_layer_name is None:
|
||||
# Reshape the input keys and values and store them in the cache.
|
||||
# Skip this if sharing KV cache with an earlier attention layer.
|
||||
# NOTE(woosuk): Here, key and value are padded while slot_mapping is
|
||||
# not padded. However, we don't need to do key[:num_actual_tokens]
|
||||
# and value[:num_actual_tokens] because the reshape_and_cache_flash
|
||||
# op uses the slot_mapping's shape to determine the number of
|
||||
# actual tokens.
|
||||
if not current_platform.is_rocm():
|
||||
reshape_and_cache_flash(
|
||||
key,
|
||||
value,
|
||||
key_cache,
|
||||
value_cache,
|
||||
attn_metadata.slot_mapping,
|
||||
self.kv_cache_dtype,
|
||||
layer._k_scale,
|
||||
layer._v_scale,
|
||||
)
|
||||
else:
|
||||
reshape_and_cache_cuda(
|
||||
key,
|
||||
value,
|
||||
key_cache,
|
||||
value_cache,
|
||||
attn_metadata.slot_mapping,
|
||||
self.kv_cache_dtype,
|
||||
layer._k_scale,
|
||||
layer._v_scale,
|
||||
)
|
||||
|
||||
if self.kv_cache_dtype.startswith("fp8"):
|
||||
key_cache = key_cache.view(torch.float8_e4m3fn)
|
||||
value_cache = value_cache.view(torch.float8_e4m3fn)
|
||||
num_tokens, num_heads, head_size = query.shape
|
||||
query, _ = ops.scaled_fp8_quant(
|
||||
query.reshape(
|
||||
(num_tokens, num_heads * head_size)).contiguous(),
|
||||
layer._q_scale)
|
||||
query = query.reshape((num_tokens, num_heads, head_size))
|
||||
|
||||
# Compute attention and update output up to `num_actual_tokens`.
|
||||
use_local_attn = \
|
||||
(self.use_irope and attn_metadata.local_attn_metadata is not None)
|
||||
|
||||
if not attn_metadata.use_cascade or use_local_attn:
|
||||
if use_local_attn:
|
||||
assert attn_metadata.local_attn_metadata is not None
|
||||
local_metadata = attn_metadata.local_attn_metadata
|
||||
cu_seqlens_q = local_metadata.local_query_start_loc
|
||||
seqused_k = local_metadata.local_seqused_k
|
||||
max_seqlen_q = local_metadata.local_max_query_len
|
||||
max_seqlen_k = local_metadata.local_max_seq_len
|
||||
block_table = local_metadata.local_block_table
|
||||
scheduler_metadata = local_metadata.local_scheduler_metadata
|
||||
else:
|
||||
cu_seqlens_q = attn_metadata.query_start_loc
|
||||
seqused_k = attn_metadata.seq_lens
|
||||
max_seqlen_q = attn_metadata.max_query_len
|
||||
max_seqlen_k = attn_metadata.max_seq_len
|
||||
block_table = attn_metadata.block_table
|
||||
scheduler_metadata = attn_metadata.scheduler_metadata
|
||||
|
||||
descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1])
|
||||
|
||||
if not current_platform.is_rocm():
|
||||
flash_attn_varlen_func(
|
||||
q=query[:num_actual_tokens],
|
||||
k=key_cache,
|
||||
v=value_cache,
|
||||
out=output[:num_actual_tokens],
|
||||
cu_seqlens_q=cu_seqlens_q,
|
||||
max_seqlen_q=max_seqlen_q,
|
||||
seqused_k=seqused_k,
|
||||
max_seqlen_k=max_seqlen_k,
|
||||
softmax_scale=self.scale,
|
||||
causal=True,
|
||||
alibi_slopes=self.alibi_slopes,
|
||||
window_size=self.sliding_window,
|
||||
block_table=block_table,
|
||||
softcap=self.logits_soft_cap,
|
||||
scheduler_metadata=scheduler_metadata,
|
||||
fa_version=self.vllm_flash_attn_version,
|
||||
q_descale=layer._q_scale.expand(descale_shape),
|
||||
k_descale=layer._k_scale.expand(descale_shape),
|
||||
v_descale=layer._v_scale.expand(descale_shape),
|
||||
num_splits=attn_metadata.max_num_splits,
|
||||
)
|
||||
else:
|
||||
if envs.VLLM_USE_PA_PRINT_PARAM:
|
||||
print("PA SIZE:")
|
||||
print(f"q.shape = {query[:num_actual_tokens].shape}, key_cache.shape = {key_cache.shape}, value_cache.shape = {value_cache.shape}")
|
||||
print(f"cu_seqlens_q.shape = {cu_seqlens_q.shape}, max_seqlen_q = {max_seqlen_q}, seqused_k.shape = {seqused_k.shape}, max_seqlen_k = {max_seqlen_k}")
|
||||
print(f"softmax_scale = {self.scale:.3f}, alibi_slopes = {self.alibi_slopes}, window_size = {self.sliding_window}, block_tables.shape = {block_table.shape}, softcap = {self.logits_soft_cap}, scheduler_metadata = {scheduler_metadata}")
|
||||
|
||||
vllm_flash_attn_varlen_func(
|
||||
q=query[:num_actual_tokens],
|
||||
k=key_cache,
|
||||
v=value_cache,
|
||||
out=output[:num_actual_tokens],
|
||||
cu_seqlens_q=cu_seqlens_q,
|
||||
max_seqlen_q=max_seqlen_q,
|
||||
seqused_k=seqused_k,
|
||||
max_seqlen_k=max_seqlen_k,
|
||||
softmax_scale=self.scale,
|
||||
causal=True,
|
||||
alibi_slopes=self.alibi_slopes,
|
||||
window_size=self.sliding_window,
|
||||
block_table=block_table,
|
||||
softcap=self.logits_soft_cap,
|
||||
scheduler_metadata=scheduler_metadata,
|
||||
# fa_version=self.vllm_flash_attn_version,
|
||||
# q_descale=layer._q_scale.expand(descale_shape),
|
||||
# k_descale=layer._k_scale.expand(descale_shape),
|
||||
# v_descale=layer._v_scale.expand(descale_shape),
|
||||
# num_splits=attn_metadata.max_num_splits,
|
||||
is_prefix_cache=True,
|
||||
)
|
||||
return output
|
||||
|
||||
assert not use_local_attn, (
|
||||
"Cascade attention does not support local attention.")
|
||||
# Cascade attention (rare case).
|
||||
if not current_platform.is_rocm():
|
||||
cascade_attention(
|
||||
output[:num_actual_tokens],
|
||||
query[:num_actual_tokens],
|
||||
key_cache,
|
||||
value_cache,
|
||||
cu_query_lens=attn_metadata.query_start_loc,
|
||||
max_query_len=attn_metadata.max_query_len,
|
||||
cu_prefix_query_lens=attn_metadata.cu_prefix_query_lens,
|
||||
prefix_kv_lens=attn_metadata.prefix_kv_lens,
|
||||
suffix_kv_lens=attn_metadata.suffix_kv_lens,
|
||||
max_kv_len=attn_metadata.max_seq_len,
|
||||
softmax_scale=self.scale,
|
||||
alibi_slopes=self.alibi_slopes,
|
||||
sliding_window=self.sliding_window,
|
||||
logits_soft_cap=self.logits_soft_cap,
|
||||
block_table=attn_metadata.block_table,
|
||||
common_prefix_len=attn_metadata.common_prefix_len,
|
||||
fa_version=self.vllm_flash_attn_version,
|
||||
prefix_scheduler_metadata=attn_metadata.prefix_scheduler_metadata,
|
||||
suffix_scheduler_metadata=attn_metadata.scheduler_metadata,
|
||||
q_descale=layer._q_scale,
|
||||
k_descale=layer._k_scale,
|
||||
v_descale=layer._v_scale,
|
||||
)
|
||||
else:
|
||||
cascade_attention(
|
||||
output[:num_actual_tokens],
|
||||
query[:num_actual_tokens],
|
||||
key_cache,
|
||||
value_cache,
|
||||
cu_query_lens=attn_metadata.query_start_loc,
|
||||
max_query_len=attn_metadata.max_query_len,
|
||||
cu_prefix_query_lens=attn_metadata.cu_prefix_query_lens,
|
||||
prefix_kv_lens=attn_metadata.prefix_kv_lens,
|
||||
suffix_kv_lens=attn_metadata.suffix_kv_lens,
|
||||
max_kv_len=attn_metadata.max_seq_len,
|
||||
softmax_scale=self.scale,
|
||||
alibi_slopes=self.alibi_slopes,
|
||||
sliding_window=self.sliding_window,
|
||||
logits_soft_cap=self.logits_soft_cap,
|
||||
block_table=attn_metadata.block_table,
|
||||
common_prefix_len=attn_metadata.common_prefix_len,
|
||||
fa_version=2, #self.vllm_flash_attn_version,
|
||||
prefix_scheduler_metadata=attn_metadata.prefix_scheduler_metadata,
|
||||
suffix_scheduler_metadata=attn_metadata.scheduler_metadata,
|
||||
# q_descale=layer._q_scale,
|
||||
# k_descale=layer._k_scale,
|
||||
# v_descale=layer._v_scale,
|
||||
)
|
||||
return output
|
||||
|
||||
|
||||
def use_cascade_attention(
|
||||
common_prefix_len: int,
|
||||
query_lens: np.ndarray,
|
||||
num_query_heads: int,
|
||||
num_kv_heads: int,
|
||||
use_alibi: bool,
|
||||
use_sliding_window: bool,
|
||||
num_sms: int,
|
||||
) -> bool:
|
||||
"""Decide whether to use cascade attention.
|
||||
|
||||
This function 1) checks whether cascade attention is supported with the
|
||||
given configuration, and 2) heuristically decides whether using cascade
|
||||
attention can improve performance.
|
||||
"""
|
||||
# Too short common prefix. Probably not worth using cascade attention.
|
||||
# We use an arbitrary threshold of 256 tokens. TODO: Tune this threshold.
|
||||
# NOTE(woosuk): This is the common case. We should return False as soon as
|
||||
# possible to avoid any unnecessary computation.
|
||||
if common_prefix_len < 256:
|
||||
return False
|
||||
# Cascade attention is currently not supported with these variants.
|
||||
if use_alibi or use_sliding_window:
|
||||
return False
|
||||
# Too few queries. Probably not worth using cascade attention.
|
||||
# We use an arbitrary threshold of 8 queries. TODO: Tune this threshold.
|
||||
num_reqs = len(query_lens)
|
||||
if num_reqs < 8:
|
||||
return False
|
||||
|
||||
# Heuristics to decide whether using cascade attention is beneficial.
|
||||
# 1. When FlashDecoding is not used for normal attention, cascade attention
|
||||
# is likely to be faster since it saves memory bandwidth.
|
||||
num_queries_per_kv = num_query_heads // num_kv_heads
|
||||
# The criteria for using FlashDecoding can be found in the following link:
|
||||
# https://github.com/vllm-project/flash-attention/blob/96266b1111111f3d11aabefaf3bacbab6a89d03c/csrc/flash_attn/flash_api.cpp#L535
|
||||
use_flash_decoding = (num_queries_per_kv > 1 and not use_sliding_window
|
||||
and not use_alibi and np.all(query_lens == 1))
|
||||
if not use_flash_decoding:
|
||||
# Use cascade attention.
|
||||
return True
|
||||
|
||||
# 2. When FlashDecoding is used for normal attention, it is not clear
|
||||
# whether cascade attention is beneficial, because FlashDecoding can
|
||||
# launch more CTAs than cascade attention.
|
||||
# We use a simple performance model to compare the two methods.
|
||||
# NOTE(woosuk): The performance model is very rough and may not be
|
||||
# accurate.
|
||||
num_tokens = num_reqs
|
||||
# NOTE(woosuk): These are default tile sizes. flash-attn might use
|
||||
# different tile sizes (e.g., 64 or 256) depending on the configuration.
|
||||
q_tile_size = 128
|
||||
kv_tile_size = 128
|
||||
num_prefix_tiles = cdiv(common_prefix_len, kv_tile_size)
|
||||
|
||||
cascade_ctas = num_query_heads * cdiv(num_tokens, q_tile_size)
|
||||
cascade_waves = cdiv(cascade_ctas, num_sms)
|
||||
cascade_time = cascade_waves * num_prefix_tiles
|
||||
|
||||
flash_decoding_ctas = (num_reqs * num_kv_heads *
|
||||
cdiv(num_queries_per_kv, q_tile_size))
|
||||
flash_decoding_ctas *= num_prefix_tiles
|
||||
flash_decoding_time = cdiv(flash_decoding_ctas, num_sms)
|
||||
|
||||
# Use cascade attention if it is faster than FlashDecoding.
|
||||
return cascade_time < flash_decoding_time
|
||||
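A hedged usage sketch of the use_cascade_attention heuristic defined above, with made-up values: 32 pure-decode requests sharing a 1024-token common prefix, 32 query heads, 8 KV heads, no ALiBi or sliding window, and 108 SMs. Under the performance model above, these numbers should favor cascade attention:

import numpy as np

should_cascade = use_cascade_attention(
    common_prefix_len=1024,
    query_lens=np.ones(32, dtype=np.int32),
    num_query_heads=32,
    num_kv_heads=8,
    use_alibi=False,
    use_sliding_window=False,
    num_sms=108,
)
print(should_cascade)  # expected to be True for these illustrative values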
|
||||
|
||||
def cascade_attention(
|
||||
output: torch.Tensor,
|
||||
query: torch.Tensor,
|
||||
key_cache: torch.Tensor,
|
||||
value_cache: torch.Tensor,
|
||||
cu_query_lens: torch.Tensor,
|
||||
max_query_len: int,
|
||||
cu_prefix_query_lens: torch.Tensor,
|
||||
prefix_kv_lens: torch.Tensor,
|
||||
suffix_kv_lens: torch.Tensor,
|
||||
max_kv_len: int,
|
||||
softmax_scale: float,
|
||||
alibi_slopes: Optional[torch.Tensor],
|
||||
sliding_window: tuple[int, int],
|
||||
logits_soft_cap: float,
|
||||
block_table: torch.Tensor,
|
||||
common_prefix_len: int,
|
||||
fa_version: int,
|
||||
prefix_scheduler_metadata: Optional[torch.Tensor] = None,
|
||||
suffix_scheduler_metadata: Optional[torch.Tensor] = None,
|
||||
q_descale: Optional[torch.Tensor] = None,
|
||||
k_descale: Optional[torch.Tensor] = None,
|
||||
v_descale: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
assert alibi_slopes is None, ("Cascade attention does not support ALiBi.")
|
||||
# TODO: Support sliding window.
|
||||
assert sliding_window == (-1, -1), (
|
||||
"Cascade attention does not support sliding window.")
|
||||
|
||||
num_tokens = query.shape[0]
|
||||
block_size = key_cache.shape[-3]
|
||||
assert common_prefix_len % block_size == 0
|
||||
num_common_kv_blocks = common_prefix_len // block_size
|
||||
assert num_common_kv_blocks > 0
|
||||
descale_shape = (cu_prefix_query_lens.shape[0] - 1, key_cache.shape[-2])
|
||||
|
||||
# Process shared prefix.
|
||||
if not current_platform.is_rocm():
|
||||
prefix_output, prefix_lse = flash_attn_varlen_func(
|
||||
q=query,
|
||||
k=key_cache,
|
||||
v=value_cache,
|
||||
cu_seqlens_q=cu_prefix_query_lens,
|
||||
seqused_k=prefix_kv_lens,
|
||||
max_seqlen_q=num_tokens,
|
||||
max_seqlen_k=common_prefix_len,
|
||||
softmax_scale=softmax_scale,
|
||||
causal=False,
|
||||
window_size=sliding_window,
|
||||
block_table=block_table[:1],
|
||||
softcap=logits_soft_cap,
|
||||
return_softmax_lse=True,
|
||||
scheduler_metadata=prefix_scheduler_metadata,
|
||||
fa_version=fa_version,
|
||||
q_descale=q_descale.expand(descale_shape)
|
||||
if q_descale is not None else None,
|
||||
k_descale=k_descale.expand(descale_shape)
|
||||
if k_descale is not None else None,
|
||||
v_descale=v_descale.expand(descale_shape)
|
||||
if v_descale is not None else None,
|
||||
)
|
||||
else:
|
||||
prefix_output, prefix_lse, _ = vllm_flash_attn_varlen_func(
|
||||
q=query,
|
||||
k=key_cache,
|
||||
v=value_cache,
|
||||
cu_seqlens_q=cu_prefix_query_lens,
|
||||
seqused_k=prefix_kv_lens,
|
||||
max_seqlen_q=num_tokens,
|
||||
max_seqlen_k=common_prefix_len,
|
||||
softmax_scale=softmax_scale,
|
||||
causal=False,
|
||||
window_size=sliding_window,
|
||||
block_table=block_table[:1],
|
||||
softcap=logits_soft_cap,
|
||||
return_softmax_lse=True,
|
||||
scheduler_metadata=prefix_scheduler_metadata,
|
||||
# fa_version=fa_version,
|
||||
# q_descale=q_descale.expand(descale_shape)
|
||||
# if q_descale is not None else None,
|
||||
# k_descale=k_descale.expand(descale_shape)
|
||||
# if k_descale is not None else None,
|
||||
# v_descale=v_descale.expand(descale_shape)
|
||||
# if v_descale is not None else None,
|
||||
is_prefix_cache=True,
|
||||
)
|
||||
|
||||
descale_shape = (cu_query_lens.shape[0] - 1, key_cache.shape[-2])
|
||||
|
||||
# Process suffix per query.
|
||||
if not current_platform.is_rocm():
|
||||
suffix_output, suffix_lse = flash_attn_varlen_func(
|
||||
q=query,
|
||||
k=key_cache,
|
||||
v=value_cache,
|
||||
cu_seqlens_q=cu_query_lens,
|
||||
seqused_k=suffix_kv_lens,
|
||||
max_seqlen_q=max_query_len,
|
||||
max_seqlen_k=max_kv_len - common_prefix_len,
|
||||
softmax_scale=softmax_scale,
|
||||
causal=True,
|
||||
window_size=sliding_window,
|
||||
block_table=block_table[:, num_common_kv_blocks:],
|
||||
softcap=logits_soft_cap,
|
||||
return_softmax_lse=True,
|
||||
scheduler_metadata=suffix_scheduler_metadata,
|
||||
fa_version=fa_version,
|
||||
q_descale=q_descale.expand(descale_shape)
|
||||
if q_descale is not None else None,
|
||||
k_descale=k_descale.expand(descale_shape)
|
||||
if k_descale is not None else None,
|
||||
v_descale=v_descale.expand(descale_shape)
|
||||
if v_descale is not None else None,
|
||||
)
|
||||
else:
|
||||
suffix_output, suffix_lse, _ = vllm_flash_attn_varlen_func(
|
||||
q=query,
|
||||
k=key_cache,
|
||||
v=value_cache,
|
||||
cu_seqlens_q=cu_query_lens,
|
||||
seqused_k=suffix_kv_lens,
|
||||
max_seqlen_q=max_query_len,
|
||||
max_seqlen_k=max_kv_len - common_prefix_len,
|
||||
softmax_scale=softmax_scale,
|
||||
causal=True,
|
||||
window_size=sliding_window,
|
||||
block_table=block_table[:, num_common_kv_blocks:],
|
||||
softcap=logits_soft_cap,
|
||||
return_softmax_lse=True,
|
||||
scheduler_metadata=suffix_scheduler_metadata,
|
||||
# fa_version=fa_version,
|
||||
# q_descale=q_descale.expand(descale_shape)
|
||||
# if q_descale is not None else None,
|
||||
# k_descale=k_descale.expand(descale_shape)
|
||||
# if k_descale is not None else None,
|
||||
# v_descale=v_descale.expand(descale_shape)
|
||||
# if v_descale is not None else None,
|
||||
is_prefix_cache=True,
|
||||
)
|
||||
|
||||
# Merge prefix and suffix outputs, and store the result in output.
|
||||
merge_attn_states(output, prefix_output, prefix_lse, suffix_output,
|
||||
suffix_lse)
|
||||
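cascade_attention above combines the shared-prefix and per-request-suffix results with merge_attn_states, which weights the two partial outputs by their softmax log-sum-exp (LSE) values. A pure-PyTorch reference sketch of that merge, assuming outputs of shape [num_tokens, num_heads, head_size] and LSEs of shape [num_heads, num_tokens]; the real op is a fused kernel that also writes the result in place:

import torch

def merge_reference(prefix_out: torch.Tensor, prefix_lse: torch.Tensor,
                    suffix_out: torch.Tensor,
                    suffix_lse: torch.Tensor) -> torch.Tensor:
    # Per-(head, token) weights proportional to exp(lse), computed stably by
    # subtracting the running maximum before exponentiating.
    max_lse = torch.maximum(prefix_lse, suffix_lse)
    w_prefix = torch.exp(prefix_lse - max_lse)
    w_suffix = torch.exp(suffix_lse - max_lse)
    denom = w_prefix + w_suffix
    # Bring weights to [num_tokens, num_heads, 1] so they broadcast over
    # head_size.
    w_prefix = (w_prefix / denom).transpose(0, 1).unsqueeze(-1)
    w_suffix = (w_suffix / denom).transpose(0, 1).unsqueeze(-1)
    return w_prefix * prefix_out + w_suffix * suffix_out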
680
vllm/v1/attention/backends/flashinfer.py
Normal file
@@ -0,0 +1,680 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Attention layer with FlashInfer."""
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
import torch
|
||||
from flashinfer import (BatchDecodeWithPagedKVCacheWrapper,
|
||||
BatchPrefillWithPagedKVCacheWrapper,
|
||||
MultiLevelCascadeAttentionWrapper)
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
||||
AttentionType)
|
||||
from vllm.attention.layer import Attention
|
||||
from vllm.config import VllmConfig, get_layers_from_vllm_config
|
||||
from vllm.logger import init_logger
|
||||
from vllm.v1.attention.backends.flash_attn import use_cascade_attention
|
||||
from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
|
||||
CommonAttentionMetadata,
|
||||
get_kv_cache_layout)
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||
from vllm.v1.worker.block_table import BlockTable
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.worker.gpu_input_batch import InputBatch
|
||||
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
|
||||
|
||||
FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class FlashInferBackend(AttentionBackend):
|
||||
|
||||
accept_output_buffer: bool = True
|
||||
|
||||
@classmethod
|
||||
def get_supported_head_sizes(cls) -> list[int]:
|
||||
# https://github.com/flashinfer-ai/flashinfer/blob/3d55c71a62052c590c130897d3a3db49b14fcc34/include/flashinfer/utils.cuh#L157
|
||||
return [64, 128, 256]
|
||||
|
||||
@classmethod
|
||||
def validate_head_size(cls, head_size: int) -> None:
|
||||
supported_head_sizes = cls.get_supported_head_sizes()
|
||||
if head_size not in supported_head_sizes:
|
||||
attn_type = cls.__name__.removesuffix("Backend")
|
||||
raise ValueError(
|
||||
f"Head size {head_size} is not supported by {attn_type}. "
|
||||
f"Supported head sizes are: {supported_head_sizes}. "
|
||||
"Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use "
|
||||
"FlexAttention backend which supports all head sizes.")
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "FLASHINFER_VLLM_V1"
|
||||
|
||||
@staticmethod
|
||||
def get_impl_cls() -> type[FlashInferImpl]:
|
||||
return FlashInferImpl
|
||||
|
||||
@staticmethod
|
||||
def get_metadata_cls() -> type[FlashInferMetadata]:
|
||||
return FlashInferMetadata
|
||||
|
||||
@staticmethod
|
||||
def get_builder_cls() -> type[FlashInferMetadataBuilder]:
|
||||
return FlashInferMetadataBuilder
|
||||
|
||||
@staticmethod
|
||||
def get_kv_cache_shape(
|
||||
num_blocks: int,
|
||||
block_size: int,
|
||||
num_kv_heads: int,
|
||||
head_size: int,
|
||||
) -> tuple[int, ...]:
|
||||
return (num_blocks, 2, block_size, num_kv_heads, head_size)
|
||||
|
||||
@staticmethod
|
||||
def get_kv_cache_stride_order() -> tuple[int, ...]:
|
||||
# `stride_order` indicates the permutation that gets us from
|
||||
# `get_kv_cache_shape` to the actual memory layout we want.
|
||||
cache_layout = get_kv_cache_layout()
|
||||
if cache_layout == "NHD":
|
||||
stride_order = (0, 1, 2, 3, 4)
|
||||
elif cache_layout == "HND":
|
||||
stride_order = (0, 1, 3, 2, 4)
|
||||
else:
|
||||
raise ValueError(f"Unknown cache layout format {cache_layout}.")
|
||||
return stride_order
|
||||
|
||||
|
||||
@dataclass
|
||||
class PerLayerParameters:
|
||||
"""
|
||||
Currently, the FlashInfer backend only supports models in which all layers share
|
||||
the same values for the following hyperparameters.
|
||||
"""
|
||||
|
||||
window_left: int
|
||||
logits_soft_cap: Optional[float]
|
||||
sm_scale: float
|
||||
|
||||
|
||||
def get_per_layer_parameters(
|
||||
vllm_config: VllmConfig) -> dict[str, PerLayerParameters]:
|
||||
"""
|
||||
Scan all attention layers and determine some hyperparameters
|
||||
to use during `plan`.
|
||||
"""
|
||||
|
||||
layers = get_layers_from_vllm_config(vllm_config, Attention)
|
||||
per_layer_params: dict[str, PerLayerParameters] = {}
|
||||
|
||||
for key, layer in layers.items():
|
||||
impl = layer.impl
|
||||
assert isinstance(impl, FlashInferImpl)
|
||||
|
||||
# Infer hyperparameters from the attention layer
|
||||
window_size = impl.sliding_window
|
||||
window_left = window_size[0] if window_size is not None else -1
|
||||
logits_soft_cap = impl.logits_soft_cap
|
||||
sm_scale = impl.scale
|
||||
|
||||
per_layer_params[key] = PerLayerParameters(window_left,
|
||||
logits_soft_cap, sm_scale)
|
||||
|
||||
return per_layer_params
|
||||
|
||||
|
||||
def infer_global_hyperparameters(
|
||||
per_layer_params: dict[str, PerLayerParameters]) -> PerLayerParameters:
|
||||
"""
|
||||
Currently, the FlashInfer backend only supports models in which all layers share
|
||||
the same values for the following hyperparameters:
|
||||
- `window_left`
|
||||
- `logits_soft_cap`
|
||||
- `sm_scale`
|
||||
|
||||
So this function asserts that all layers share the same values for these
|
||||
hyperparameters and returns the global values.
|
||||
"""
|
||||
|
||||
assert len(per_layer_params) > 0, "No attention layers found in the model."
|
||||
|
||||
param_sets = list(per_layer_params.values())
|
||||
global_params = param_sets[0]
|
||||
for params in param_sets:
|
||||
assert params == global_params, (
|
||||
"FlashInfer backend currently only supports models in which all "
|
||||
"layers share the same values for the following hyperparameters: "
|
||||
"`window_left`, `logits_soft_cap`, `sm_scale`.")
|
||||
|
||||
return global_params
|
||||
|
||||
|
||||
@dataclass
|
||||
class FlashInferMetadata:
|
||||
|
||||
num_actual_tokens: int # Number of tokens excluding padding.
|
||||
|
||||
# (batch_size + 1,). The cumulative subquery lengths of the sequences in
|
||||
# the batch, used to index into subquery. E.g., if the subquery length
|
||||
# is [4, 6], it is [0, 4, 10].
|
||||
qo_indptr: torch.Tensor
|
||||
# An example for paged_kv_indices, paged_kv_indptr:
|
||||
# request 1, page indices [0, 5, 8]
|
||||
# request 2, page indices [1, 6, 7]
|
||||
# request 3, page indices [3, 4]
|
||||
# paged_kv_indices is a concatenation of page indices of all requests:
|
||||
# [0, 5, 8, 1, 6, 7, 3, 4]
|
||||
# paged_kv_indptr is used to index into paged_kv_indices:
|
||||
# [0, 3, 6, 8]
|
||||
# The indptr of the paged kv cache, shape: [batch_size + 1]
|
||||
paged_kv_indptr: torch.Tensor
|
||||
# The page indices of the paged kv cache
|
||||
paged_kv_indices: torch.Tensor
|
||||
# The number of entries in the last page of each request in
|
||||
# the paged kv cache, shape: [batch_size]
|
||||
paged_kv_last_page_len: torch.Tensor
|
||||
# The number of query/output heads
|
||||
num_qo_heads: int
|
||||
# The number of key/value heads
|
||||
num_kv_heads: int
|
||||
# The dimension of the attention heads
|
||||
head_dim: int
|
||||
# Block size of vllm
|
||||
page_size: int
|
||||
# The data type of the paged kv cache
|
||||
data_type: torch.dtype
|
||||
# The data type of the query
|
||||
q_data_type: torch.dtype
|
||||
|
||||
slot_mapping: torch.Tensor
|
||||
|
||||
# For handling prefill decode split
|
||||
num_decodes: int
|
||||
num_decode_tokens: int
|
||||
num_prefills: int
|
||||
num_prefill_tokens: int
|
||||
|
||||
# For cascade attention.
|
||||
use_cascade: bool
|
||||
shared_qo_indptr: Optional[torch.Tensor] = None
|
||||
shared_kv_page_indptr: Optional[torch.Tensor] = None
|
||||
shared_kv_page_indices: Optional[torch.Tensor] = None
|
||||
shared_kv_last_page_len: Optional[torch.Tensor] = None
|
||||
|
||||
prefill_wrapper: Optional[BatchPrefillWithPagedKVCacheWrapper] = None
|
||||
decode_wrapper: Optional[BatchDecodeWithPagedKVCacheWrapper] = None
|
||||
cascade_wrapper: Optional[MultiLevelCascadeAttentionWrapper] = None
|
||||
|
||||
@property
|
||||
def query_start_loc(self):
|
||||
# The GPUModelRunner expects to be able to access this property.
|
||||
return self.qo_indptr
|
||||
|
||||
def __post_init__(self):
|
||||
if self.head_dim is not None:
|
||||
FlashInferBackend.validate_head_size(self.head_dim)
|
||||
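A small sketch constructing the paged-KV indexing tensors described in the FlashInferMetadata comments above, reusing the same example page tables; the page_size and sequence lengths are illustrative assumptions:

import torch

# request 1 -> pages [0, 5, 8], request 2 -> [1, 6, 7], request 3 -> [3, 4]
page_tables = [[0, 5, 8], [1, 6, 7], [3, 4]]

paged_kv_indices = torch.tensor(
    [page for pages in page_tables for page in pages], dtype=torch.int32)

indptr = [0]
for pages in page_tables:
    indptr.append(indptr[-1] + len(pages))
paged_kv_indptr = torch.tensor(indptr, dtype=torch.int32)

assert paged_kv_indices.tolist() == [0, 5, 8, 1, 6, 7, 3, 4]
assert paged_kv_indptr.tolist() == [0, 3, 6, 8]

# paged_kv_last_page_len holds the number of valid slots in each request's
# final page; with an assumed page_size of 16 and sequence lengths
# [40, 40, 20], that is [8, 8, 4].
page_size = 16
seq_lens = [40, 40, 20]
paged_kv_last_page_len = torch.tensor(
    [s - (len(p) - 1) * page_size for s, p in zip(seq_lens, page_tables)],
    dtype=torch.int32)
assert paged_kv_last_page_len.tolist() == [8, 8, 4]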
|
||||
|
||||
class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
|
||||
|
||||
def __init__(self, runner: GPUModelRunner, kv_cache_spec: AttentionSpec,
|
||||
block_table: BlockTable):
|
||||
self.runner = runner
|
||||
self._workspace_buffer = None
|
||||
self._prefill_wrapper = None # Wrapper for prefill/append
|
||||
self._decode_wrapper = None # Wrapper for decode
|
||||
self._cascade_wrapper = None # Wrapper for cascade attention
|
||||
|
||||
# Global hyperparameters shared by all attention layers
|
||||
self.global_hyperparameters: Optional[PerLayerParameters] = None
|
||||
|
||||
self.vllm_config = runner.vllm_config
|
||||
self.kv_cache_spec = kv_cache_spec
|
||||
self.block_table = block_table
|
||||
|
||||
def reorder_batch(self, input_batch: InputBatch,
|
||||
scheduler_output: SchedulerOutput) -> bool:
|
||||
# We now want to reorder the batch so that the "decode" requests are at
# the front and the "prefill" requests are at the back, using the least
# number of swaps possible. (NOTE: for now we loosely use "decode" to mean
# requests where attention is likely memory-bound and "prefill" to mean
# requests where attention is likely compute-bound; TODO(lucas): figure
# out a better naming here)
|
||||
decodes = []
|
||||
prefills = []
|
||||
num_decode_tokens = 0
|
||||
num_prefill_tokens = 0
|
||||
|
||||
for i, req_id in enumerate(input_batch.req_ids):
|
||||
num_tokens = scheduler_output.num_scheduled_tokens[req_id]
|
||||
# for now treat 1 scheduled token as "decode" even if it's not,
|
||||
# we should update this to something like < 8 in the future but
|
||||
# currently the decode run only supports num_tokens = 1
|
||||
if num_tokens == 1:
|
||||
decodes.append(i)
|
||||
num_decode_tokens += num_tokens
|
||||
else:
|
||||
prefills.append(i)
|
||||
num_prefill_tokens += num_tokens
|
||||
|
||||
# We hope that this is fairly minimal since decodes
|
||||
# should be around for a number of iterations so hopefully they are
|
||||
# relatively stationary (and new requests are generally appended to the
|
||||
# persistent batch so already should be at the back)
|
||||
# To achieve this we loop over the decodes in descending order and
|
||||
# the prefills in ascending order. We swap decodes from the "back"
|
||||
# i.e. past where the last decode should be in the reordered batch, with
|
||||
# prefills from the front of the batch.
|
||||
# `decodes` and `prefills` are already in ascending order just based on
|
||||
# the above loop
|
||||
num_decodes = len(decodes)
|
||||
num_prefills = len(prefills)
|
||||
modified_batch = False
|
||||
|
||||
for i in range(1, min(num_decodes, num_prefills) + 1):
|
||||
# If the decode is at the "back" of the batch, i, we can swap it
|
||||
# with the prefill closest to the front of the batch
|
||||
decode_idx = decodes[num_decodes - i]
|
||||
if decode_idx < num_decodes:
|
||||
break
|
||||
|
||||
input_batch.swap_states(prefills[i - 1], decode_idx)
|
||||
modified_batch = True
|
||||
|
||||
# Save for next `build` call
|
||||
# TODO(lucas): this is a bit of a hack, we should probably have a
|
||||
# better way of doing this
|
||||
self._num_decodes = num_decodes
|
||||
self._num_prefills = num_prefills
|
||||
self._num_decode_tokens = num_decode_tokens
|
||||
self._num_prefill_tokens = num_prefill_tokens
|
||||
|
||||
return modified_batch
|
||||
|
||||
def _get_workspace_buffer(self):
|
||||
if self._workspace_buffer is None:
|
||||
self._workspace_buffer = torch.empty(
|
||||
FLASHINFER_WORKSPACE_BUFFER_SIZE,
|
||||
dtype=torch.uint8,
|
||||
device=self.runner.device)
|
||||
return self._workspace_buffer
|
||||
|
||||
def _get_prefill_wrapper(self):
|
||||
if self._prefill_wrapper is None:
|
||||
self._prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper(
|
||||
self._get_workspace_buffer(), get_kv_cache_layout())
|
||||
return self._prefill_wrapper
|
||||
|
||||
def _get_decode_wrapper(self):
|
||||
if self._decode_wrapper is None:
|
||||
num_qo_heads = (self.runner.model_config.get_num_attention_heads(
|
||||
self.runner.parallel_config))
|
||||
num_kv_heads = self.runner.model_config.get_num_kv_heads(
|
||||
self.runner.parallel_config)
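# Heuristic: use the tensor-core decode kernel when the GQA group size
# (query heads per KV head) exceeds 4, or when forced via
# VLLM_FLASHINFER_FORCE_TENSOR_CORES.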
|
||||
use_tensor_cores = envs.VLLM_FLASHINFER_FORCE_TENSOR_CORES or (
|
||||
num_qo_heads // num_kv_heads > 4)
|
||||
self._decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
|
||||
self._get_workspace_buffer(),
|
||||
get_kv_cache_layout(),
|
||||
use_tensor_cores=use_tensor_cores)
|
||||
return self._decode_wrapper
|
||||
|
||||
def _get_cascade_wrapper(self):
|
||||
if self._cascade_wrapper is None:
|
||||
self._cascade_wrapper = MultiLevelCascadeAttentionWrapper(
|
||||
2, self._get_workspace_buffer(), get_kv_cache_layout())
|
||||
return self._cascade_wrapper
|
||||
|
||||
def _plan(self, attn_metadata: FlashInferMetadata):
|
||||
if self.global_hyperparameters is None:
|
||||
self.global_hyperparameters = infer_global_hyperparameters(
|
||||
get_per_layer_parameters(self.vllm_config))
|
||||
if attn_metadata.use_cascade:
|
||||
attn_metadata.cascade_wrapper = self._get_cascade_wrapper()
|
||||
attn_metadata.cascade_wrapper.plan(
|
||||
[attn_metadata.shared_qo_indptr, attn_metadata.qo_indptr],
|
||||
[
|
||||
attn_metadata.shared_kv_page_indptr,
|
||||
attn_metadata.paged_kv_indptr
|
||||
],
|
||||
[
|
||||
attn_metadata.shared_kv_page_indices,
|
||||
attn_metadata.paged_kv_indices
|
||||
],
|
||||
[
|
||||
attn_metadata.shared_kv_last_page_len,
|
||||
attn_metadata.paged_kv_last_page_len
|
||||
],
|
||||
attn_metadata.num_qo_heads,
|
||||
attn_metadata.num_kv_heads,
|
||||
attn_metadata.head_dim,
|
||||
attn_metadata.page_size,
|
||||
causal=True,
|
||||
sm_scale=self.global_hyperparameters.sm_scale,
|
||||
window_left=self.global_hyperparameters.window_left,
|
||||
logits_soft_cap=self.global_hyperparameters.logits_soft_cap,
|
||||
q_data_type=attn_metadata.q_data_type,
|
||||
)
|
||||
else:
|
||||
# Regular attention (common case).
|
||||
# Decodes are at the front and prefills are at the back,
|
||||
# according to reorder_batch()
|
||||
if self._num_prefills > 0:
|
||||
# Decodes are first so prefills start after the last decode
|
||||
prefill_start = self._num_decodes
|
||||
attn_metadata.prefill_wrapper = self._get_prefill_wrapper()
|
||||
assert attn_metadata.qo_indptr[prefill_start:].shape[
|
||||
0] == self._num_prefills + 1
|
||||
assert attn_metadata.paged_kv_indptr[prefill_start:].shape[
|
||||
0] == self._num_prefills + 1
|
||||
assert attn_metadata.paged_kv_last_page_len[
|
||||
prefill_start:].shape[0] == self._num_prefills
|
||||
# Since prefill_wrapper.run() will be called with
|
||||
# query[num_decode_tokens:] we need to adjust the qo_indptr
|
||||
# to be relative to the start of the prefill queries.
|
||||
qo_indptr = attn_metadata.qo_indptr[
|
||||
prefill_start:] - attn_metadata.qo_indptr[prefill_start]
|
||||
attn_metadata.prefill_wrapper.plan(
|
||||
qo_indptr,
|
||||
attn_metadata.paged_kv_indptr[prefill_start:],
|
||||
attn_metadata.paged_kv_indices,
|
||||
attn_metadata.paged_kv_last_page_len[prefill_start:],
|
||||
attn_metadata.num_qo_heads,
|
||||
attn_metadata.num_kv_heads,
|
||||
attn_metadata.head_dim,
|
||||
attn_metadata.page_size,
|
||||
causal=True,
|
||||
sm_scale=self.global_hyperparameters.sm_scale,
|
||||
window_left=self.global_hyperparameters.window_left,
|
||||
logits_soft_cap=self.global_hyperparameters.
|
||||
logits_soft_cap,
|
||||
q_data_type=attn_metadata.q_data_type,
|
||||
kv_data_type=attn_metadata.data_type,
|
||||
)
|
||||
|
||||
if self._num_decodes > 0:
|
||||
attn_metadata.decode_wrapper = self._get_decode_wrapper()
|
||||
attn_metadata.decode_wrapper.plan(
|
||||
attn_metadata.paged_kv_indptr[:self._num_decodes + 1],
|
||||
attn_metadata.paged_kv_indices,
|
||||
attn_metadata.paged_kv_last_page_len[:self._num_decodes],
|
||||
attn_metadata.num_qo_heads,
|
||||
attn_metadata.num_kv_heads,
|
||||
attn_metadata.head_dim,
|
||||
attn_metadata.page_size,
|
||||
# Disable flashinfer's pos encoding and use vllm's rope.
|
||||
pos_encoding_mode="NONE",
|
||||
sm_scale=self.global_hyperparameters.sm_scale,
|
||||
window_left=self.global_hyperparameters.window_left,
|
||||
logits_soft_cap=self.global_hyperparameters.
|
||||
logits_soft_cap,
|
||||
q_data_type=attn_metadata.q_data_type,
|
||||
kv_data_type=attn_metadata.data_type,
|
||||
)
|
||||
|
||||
def build(self, common_prefix_len: int,
|
||||
common_attn_metadata: CommonAttentionMetadata):
|
||||
num_reqs = common_attn_metadata.num_reqs
|
||||
num_actual_tokens = common_attn_metadata.num_actual_tokens
|
||||
|
||||
assert self._num_decodes + self._num_prefills == num_reqs
|
||||
assert (self._num_decode_tokens +
|
||||
self._num_prefill_tokens == num_actual_tokens)
|
||||
page_size = self.kv_cache_spec.block_size
|
||||
device = self.runner.device
|
||||
qo_indptr = common_attn_metadata.query_start_loc
|
||||
seq_lens = common_attn_metadata.seq_lens
|
||||
block_table_tensor = self.block_table.get_device_tensor()[:num_reqs]
|
||||
slot_mapping = self.block_table.slot_mapping_cpu[:num_actual_tokens].to(
|
||||
self.runner.device, non_blocking=True).long()
|
||||
|
||||
block_table_bounds = (seq_lens + page_size - 1) // page_size
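# Number of KV-cache blocks each request occupies: ceil(seq_len / page_size).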
|
||||
|
||||
use_cascade = common_prefix_len > 0
|
||||
if use_cascade:
|
||||
# Grab the blocks of the shared prefix from the first request.
|
||||
assert common_prefix_len % page_size == 0
|
||||
num_common_kv_blocks = common_prefix_len // page_size
|
||||
shared_qo_indptr = torch.tensor([0, num_actual_tokens],
|
||||
dtype=torch.int32,
|
||||
device=device)
|
||||
shared_kv_page_indptr = torch.tensor([0, num_common_kv_blocks],
|
||||
dtype=torch.int32,
|
||||
device=device)
|
||||
shared_kv_page_indices = block_table_tensor[
|
||||
0, :num_common_kv_blocks]
|
||||
shared_kv_last_page_len = torch.tensor([page_size],
|
||||
dtype=torch.int32,
|
||||
device=device)
|
||||
# Remove the blocks of the shared prefix from all requests.
|
||||
block_table_tensor = block_table_tensor[:, num_common_kv_blocks:]
|
||||
block_table_bounds -= num_common_kv_blocks
|
||||
else:
|
||||
shared_qo_indptr = None
|
||||
shared_kv_page_indptr = None
|
||||
shared_kv_page_indices = None
|
||||
shared_kv_last_page_len = None
|
||||
|
||||
mask = (torch.arange(block_table_tensor.size(1),
|
||||
dtype=block_table_tensor.dtype,
|
||||
device=block_table_tensor.device).unsqueeze(0)
|
||||
< block_table_bounds.unsqueeze(1))
|
||||
paged_kv_indices = block_table_tensor[mask]
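# Keep only the in-use block ids, concatenated request by request, e.g.
# bounds [2, 3] keeps the first 2 entries of row 0 followed by the first
# 3 entries of row 1.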
|
||||
|
||||
paged_kv_indptr = torch.cat([
|
||||
torch.zeros(1,
|
||||
dtype=block_table_bounds.dtype,
|
||||
device=block_table_bounds.device),
|
||||
block_table_bounds.cumsum(dim=0, dtype=torch.int32)
|
||||
])
|
||||
|
||||
paged_kv_last_page_len = seq_lens % page_size
|
||||
paged_kv_last_page_len = torch.where(paged_kv_last_page_len == 0,
|
||||
page_size, paged_kv_last_page_len)
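# e.g. with page_size 16: seq_len 17 -> last page holds 1 token, while
# seq_len 32 -> 16 (a full last page rather than 0).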
|
||||
|
||||
attn_metadata = FlashInferMetadata(
|
||||
num_actual_tokens=num_actual_tokens,
|
||||
qo_indptr=qo_indptr,
|
||||
paged_kv_indptr=paged_kv_indptr,
|
||||
paged_kv_indices=paged_kv_indices,
|
||||
paged_kv_last_page_len=paged_kv_last_page_len,
|
||||
num_qo_heads=self.runner.num_query_heads,
|
||||
num_kv_heads=self.kv_cache_spec.num_kv_heads,
|
||||
head_dim=self.kv_cache_spec.head_size,
|
||||
page_size=page_size,
|
||||
data_type=self.kv_cache_spec.dtype,
|
||||
q_data_type=self.runner.dtype,
|
||||
slot_mapping=slot_mapping,
|
||||
num_decodes=self._num_decodes,
|
||||
num_decode_tokens=self._num_decode_tokens,
|
||||
num_prefills=self._num_prefills,
|
||||
num_prefill_tokens=self._num_prefill_tokens,
|
||||
use_cascade=use_cascade,
|
||||
shared_qo_indptr=shared_qo_indptr,
|
||||
shared_kv_page_indptr=shared_kv_page_indptr,
|
||||
shared_kv_page_indices=shared_kv_page_indices,
|
||||
shared_kv_last_page_len=shared_kv_last_page_len,
|
||||
)
|
||||
|
||||
self._plan(attn_metadata)
|
||||
|
||||
return attn_metadata
|
||||
|
||||
def use_cascade_attention(self, *args, **kwargs) -> bool:
|
||||
if self.kv_cache_spec.dtype != self.runner.model_config.dtype:
|
||||
# TODO: The cascade wrapper currently does not support setting
|
||||
# kv cache dtype to something different from query dtype.
|
||||
return False
|
||||
return use_cascade_attention(*args, **kwargs)
|
||||
|
||||
|
||||
class FlashInferImpl(AttentionImpl):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
scale: float,
|
||||
num_kv_heads: int,
|
||||
alibi_slopes: Optional[list[float]],
|
||||
sliding_window: Optional[int],
|
||||
kv_cache_dtype: str,
|
||||
blocksparse_params: Optional[dict[str, Any]] = None,
|
||||
logits_soft_cap: Optional[float] = None,
|
||||
attn_type: AttentionType = AttentionType.DECODER,
|
||||
kv_sharing_target_layer_name: Optional[int] = None,
|
||||
use_irope: bool = False,
|
||||
) -> None:
|
||||
if use_irope:
|
||||
logger.warning_once(
|
||||
"Using irope in FlashInfer is not supported yet, it will fall"
|
||||
" back to global attention for long context.")
|
||||
self.num_heads = num_heads
|
||||
self.head_size = head_size
|
||||
self.scale = float(scale)
|
||||
self.num_kv_heads = num_kv_heads
|
||||
if alibi_slopes is not None:
|
||||
alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
|
||||
self.alibi_slopes = alibi_slopes
|
||||
if sliding_window is None:
|
||||
self.sliding_window = (-1, -1)
|
||||
else:
|
||||
self.sliding_window = (sliding_window - 1, 0)
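# Stored as (window_left, window_right); only window_left is later passed
# to FlashInfer, and -1 disables the sliding window.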
|
||||
self.kv_cache_dtype = kv_cache_dtype
|
||||
self.logits_soft_cap = logits_soft_cap
|
||||
self.kv_sharing_target_layer_name = kv_sharing_target_layer_name
|
||||
|
||||
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
||||
|
||||
if attn_type != AttentionType.DECODER:
|
||||
raise NotImplementedError("Encoder self-attention and "
|
||||
"encoder/decoder cross-attention "
|
||||
"are not implemented for "
|
||||
"FlashInferImpl")
|
||||
|
||||
def forward(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: FlashInferMetadata,
|
||||
output: Optional[torch.Tensor] = None,
|
||||
output_scale: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
"""Forward pass with FlashInfer.
|
||||
|
||||
Args:
|
||||
query: shape = [num_tokens, num_heads, head_size]
|
||||
key: shape = [num_tokens, num_kv_heads, head_size]
|
||||
value: shape = [num_tokens, num_kv_heads, head_size]
|
||||
kv_cache = [num_blocks, 2, block_size, num_kv_heads, head_size]
|
||||
attn_metadata: Metadata for attention.
|
||||
Returns:
|
||||
shape = [num_tokens, num_heads * head_size]
|
||||
"""
|
||||
assert output is not None, "Output tensor must be provided."
|
||||
|
||||
if output_scale is not None:
|
||||
raise NotImplementedError(
|
||||
"fused output quantization is not yet supported"
|
||||
" for FlashInferImpl")
|
||||
|
||||
if attn_metadata is None:
|
||||
# Profiling run.
|
||||
return output
|
||||
|
||||
# IMPORTANT!
|
||||
# NOTE(woosuk): With piece-wise CUDA graphs, this method is executed in
|
||||
# eager-mode PyTorch. Thus, we need to be careful about any CPU overhead
|
||||
# in this method. For example, `view` and `slice` (or `[:n]`) operations
|
||||
# are surprisingly slow even in the case they do not invoke any GPU ops.
|
||||
# Minimize the PyTorch ops in this method as much as possible.
|
||||
# Whenever making a change in this method, please benchmark the
|
||||
# performance to make sure it does not introduce any overhead.
|
||||
|
||||
num_actual_tokens = attn_metadata.num_actual_tokens
|
||||
|
||||
if self.kv_sharing_target_layer_name is None:
|
||||
# Reshape the input keys and values and store them in the cache.
|
||||
# Skip this if sharing KV cache with an earlier attention layer.
|
||||
# NOTE(woosuk): Here, key and value are padded while slot_mapping is
|
||||
# not padded. However, we don't need to do key[:num_actual_tokens]
|
||||
# and value[:num_actual_tokens] because the reshape_and_cache_flash
|
||||
# op uses the slot_mapping's shape to determine the number of
|
||||
# actual tokens.
|
||||
torch.ops._C_cache_ops.reshape_and_cache_flash(
|
||||
key,
|
||||
value,
|
||||
kv_cache[:, 0],
|
||||
kv_cache[:, 1],
|
||||
attn_metadata.slot_mapping,
|
||||
self.kv_cache_dtype,
|
||||
layer._k_scale,
|
||||
layer._v_scale,
|
||||
)
|
||||
|
||||
window_left = (self.sliding_window[0]
|
||||
if self.sliding_window is not None else -1)
|
||||
|
||||
# Inputs and outputs may be padded for CUDA graphs
|
||||
query = query[:num_actual_tokens]
|
||||
output_padded = output
|
||||
output = output[:num_actual_tokens]
|
||||
|
||||
if attn_metadata.use_cascade:
|
||||
# Cascade attention (rare case).
|
||||
assert attn_metadata.cascade_wrapper is not None
|
||||
output.copy_(attn_metadata.cascade_wrapper.run(query, kv_cache))
|
||||
return output
|
||||
|
||||
num_decode_tokens = attn_metadata.num_decode_tokens
|
||||
num_prefill_tokens = attn_metadata.num_prefill_tokens
|
||||
|
||||
stride_order = FlashInferBackend.get_kv_cache_stride_order()
|
||||
# Regular attention (common case).
|
||||
# Decodes are at the front and prefills are at the back,
|
||||
# according to reorder_batch()
|
||||
if prefill_wrapper := attn_metadata.prefill_wrapper:
|
||||
prefill_query = query[num_decode_tokens:]
|
||||
assert prefill_query.shape[0] == num_prefill_tokens
|
||||
assert prefill_wrapper is not None
|
||||
assert prefill_wrapper._causal
|
||||
assert prefill_wrapper._window_left == window_left
|
||||
assert prefill_wrapper._logits_soft_cap == (self.logits_soft_cap
|
||||
or 0.0)
|
||||
assert prefill_wrapper._sm_scale == self.scale
|
||||
prefill_wrapper.run(
|
||||
prefill_query,
|
||||
kv_cache.permute(*stride_order),
|
||||
k_scale=layer._k_scale_float,
|
||||
v_scale=layer._v_scale_float,
|
||||
out=output[num_decode_tokens:],
|
||||
)
|
||||
|
||||
if decode_wrapper := attn_metadata.decode_wrapper:
|
||||
decode_query = query[:num_decode_tokens]
|
||||
assert decode_query.shape[0] == num_decode_tokens
|
||||
assert decode_wrapper is not None
|
||||
assert decode_wrapper._window_left == window_left
|
||||
assert decode_wrapper._logits_soft_cap == (self.logits_soft_cap
|
||||
or 0.0)
|
||||
assert decode_wrapper._sm_scale == self.scale
|
||||
decode_wrapper.run(
|
||||
decode_query,
|
||||
kv_cache.permute(*stride_order),
|
||||
k_scale=layer._k_scale_float,
|
||||
v_scale=layer._v_scale_float,
|
||||
out=output[:num_decode_tokens],
|
||||
)
|
||||
|
||||
return output_padded
|
||||
491
vllm/v1/attention/backends/flex_attention.py
Normal file
@@ -0,0 +1,491 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Attention layer with FlashAttention."""
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
import torch
|
||||
from torch.nn.attention.flex_attention import (BlockMask, _mask_mod_signature,
|
||||
_score_mod_signature,
|
||||
create_block_mask,
|
||||
flex_attention)
|
||||
|
||||
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
||||
AttentionMetadata, AttentionType,
|
||||
is_quantized_kv_cache)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
|
||||
CommonAttentionMetadata)
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||
from vllm.v1.worker.block_table import BlockTable
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
|
||||
|
||||
create_block_mask_compiled = torch.compile(create_block_mask,
|
||||
fullgraph=True,
|
||||
mode="reduce-overhead")
|
||||
flex_attention_compiled = torch.compile(flex_attention, fullgraph=True)
|
||||
|
||||
|
||||
def _offsets_to_doc_ids_tensor(offsets: torch.Tensor) -> torch.Tensor:
|
||||
device = offsets.device
|
||||
counts = offsets[1:] - offsets[:-1]
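# e.g. offsets [0, 2, 5] -> counts [2, 3] -> doc ids [0, 0, 1, 1, 1]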
|
||||
return torch.repeat_interleave(
|
||||
torch.arange(len(counts), device=device, dtype=torch.int32), counts)
|
||||
|
||||
|
||||
class FlexAttentionBackend(AttentionBackend):
|
||||
accept_output_buffer: bool = True
|
||||
|
||||
@classmethod
|
||||
def validate_head_size(cls, head_size: int) -> None:
|
||||
return # FlexAttention supports any head size
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "FLEX_ATTENTION"
|
||||
|
||||
@staticmethod
|
||||
def get_impl_cls() -> type["FlexAttentionImpl"]:
|
||||
return FlexAttentionImpl
|
||||
|
||||
@staticmethod
|
||||
def get_metadata_cls() -> type["AttentionMetadata"]:
|
||||
return FlexAttentionMetadata
|
||||
|
||||
@staticmethod
|
||||
def get_kv_cache_shape(
|
||||
num_blocks: int,
|
||||
block_size: int,
|
||||
num_kv_heads: int,
|
||||
head_size: int,
|
||||
) -> tuple[int, ...]:
|
||||
return (2, num_blocks, block_size, num_kv_heads, head_size)
|
||||
|
||||
@staticmethod
|
||||
def get_builder_cls() -> type["FlexAttentionMetadataBuilder"]:
|
||||
return FlexAttentionMetadataBuilder
|
||||
|
||||
@staticmethod
|
||||
def use_cascade_attention(*args, **kwargs) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
# @torch.compile(fullgraph=True, mode="reduce-overhead")
|
||||
def physical_to_logical_mapping(
|
||||
block_table: torch.Tensor,
|
||||
total_blocks: Optional[int] = None) -> torch.Tensor:
|
||||
"""
|
||||
Creates an inverse mapping from physical block locations to logical indices.
|
||||
|
||||
The original block_table maps from logical blocks to physical locations:
|
||||
|
||||
Logical to Physical (Original block_table):
|
||||
┌───────────────────────────────────────────┐
|
||||
│ Request 0: │
|
||||
│ │
|
||||
│ Logical Blocks: 0 1 2 3 4 5 6 7 │
|
||||
│ │ │ │ │ │ │ │ │ │
|
||||
│ v v v v v v v v │
|
||||
│ Physical Blocks: 3 5 1 7 4 2 0 6 │
|
||||
└───────────────────────────────────────────┘
|
||||
|
||||
This function creates the inverse mapping:
|
||||
|
||||
Physical to Logical (Inverse mapping):
|
||||
┌───────────────────────────────────────────┐
|
||||
│ Request 0: │
|
||||
│ │
|
||||
│ Physical Blocks: 0 1 2 3 4 5 6 7 │
|
||||
│ │ │ │ │ │ │ │ │ │
|
||||
│ v v v v v v v v │
|
||||
│ Logical Blocks: 6 2 5 0 4 1 7 3 │
|
||||
└───────────────────────────────────────────┘
|
||||
|
||||
If multiple logical blocks map to the same physical block,
|
||||
this function returns the first (minimum) logical block index.
|
||||
|
||||
If a physical block is not mapped to by any logical block,
|
||||
its value in the result will be -1.
|
||||
|
||||
|
||||
Args:
|
||||
block_table: Tensor of shape [max_reqs, max_num_blocks]
|
||||
mapping logical blocks to physical locations
|
||||
|
||||
Returns:
|
||||
A tensor of shape [max_reqs, max_physical_block]
|
||||
"""
|
||||
max_reqs, max_num_blocks = block_table.shape
|
||||
device = block_table.device
|
||||
|
||||
physical_to_logical = torch.full((max_reqs, total_blocks),
|
||||
-1,
|
||||
dtype=torch.long,
|
||||
device=device)
|
||||
|
||||
logical_indices = (torch.arange(max_num_blocks,
|
||||
device=device).unsqueeze(0).expand(
|
||||
max_reqs, -1))
|
||||
|
||||
physical_to_logical.scatter_(-1, block_table.to(torch.int64),
|
||||
logical_indices)
|
||||
# TODO Confirm - Seems like block 0 is always empty so we reset it manually
|
||||
physical_to_logical[:, 0] = -1
|
||||
return physical_to_logical
|
||||
|
||||
|
||||
def causal_mask_mod(b: torch.Tensor, h: torch.Tensor, q_idx: torch.Tensor,
|
||||
kv_idx: torch.Tensor):
|
||||
return q_idx >= kv_idx
|
||||
|
||||
|
||||
@dataclass
|
||||
class FlexAttentionMetadata:
|
||||
num_actual_tokens: int # Number of tokens excluding padding.
|
||||
max_query_len: int
|
||||
query_start_loc: torch.Tensor
|
||||
max_seq_len: int
|
||||
seq_lens: torch.Tensor
|
||||
block_table: torch.Tensor
|
||||
slot_mapping: torch.Tensor
|
||||
|
||||
use_cascade: bool
|
||||
common_prefix_len: int
|
||||
cu_prefix_query_lens: Optional[torch.Tensor]
|
||||
prefix_kv_lens: Optional[torch.Tensor]
|
||||
suffix_kv_lens: Optional[torch.Tensor]
|
||||
|
||||
# Block info
|
||||
total_cache_tokens: int
|
||||
block_size: int
|
||||
max_possible_sequence_length: int
|
||||
num_reqs: int
|
||||
physical_to_logical: torch.Tensor
|
||||
decode_offset: torch.Tensor
|
||||
|
||||
# For logging.
|
||||
num_input_tokens: int = 0 # Number of tokens including padding.
|
||||
|
||||
# Flex Metadata
|
||||
num_blocks = 0
|
||||
block_mask: Optional[BlockMask] = None
|
||||
score_mod: Optional[_score_mod_signature] = None
|
||||
mask_mod: Optional[_mask_mod_signature] = None
|
||||
logical_mask_mod: _mask_mod_signature = causal_mask_mod
|
||||
|
||||
def get_mask_mod(self) -> _mask_mod_signature:
|
||||
"""Creates the mask_mod function for FlexAttention.
|
||||
|
||||
This function creates the combined mask mod function that handles:
|
||||
1. The paged attention block mapping
|
||||
2. The mapping from packed query sequences to logical query entries
|
||||
|
||||
It also, by default, adds the decoding offset to the query indices.
With this info we create the "logical" indices that are passed to
mask_mod functions. This allows mask mod functions to be agnostic to
the layout of the query and key/value tensors.
|
||||
|
||||
TODO is_within_lower_bound: do sequences start on block_boundaries?
|
||||
"""
|
||||
# Create a lookup mapping from query indices -> request number
|
||||
request_lookup = _offsets_to_doc_ids_tensor(self.query_start_loc)
|
||||
|
||||
def final_mask_mod(
|
||||
b: torch.Tensor,
|
||||
h: torch.Tensor,
|
||||
q_idx: torch.Tensor,
|
||||
physical_kv_idx: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
# Map query indices to corresponding request indices
|
||||
q_req = request_lookup[q_idx]
|
||||
|
||||
# Convert physical KV indices to logical indices
|
||||
physical_kv_block = physical_kv_idx // self.block_size
|
||||
physical_kv_offset = physical_kv_idx % self.block_size
|
||||
logical_block_idx = self.physical_to_logical[q_req,
|
||||
physical_kv_block]
|
||||
logical_kv_idx = logical_block_idx * self.block_size + physical_kv_offset # noqa: E501
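# e.g. with block_size 16, physical_kv_idx 35 lands in physical block 2 at
# offset 3; if that block is logical block 5 for this request, then
# logical_kv_idx = 5 * 16 + 3 = 83.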
|
||||
|
||||
# Determine valid kv indices
|
||||
live_block = logical_block_idx >= 0
|
||||
within_upper_bound = logical_kv_idx < self.seq_lens[q_req]
|
||||
within_lower_bound = logical_kv_idx >= 0
|
||||
|
||||
is_valid = live_block & within_upper_bound & within_lower_bound
|
||||
|
||||
# Convert physical query indices to logical indices
|
||||
local_q_idx = q_idx - self.query_start_loc[q_req]
|
||||
logical_q_idx = local_q_idx + self.decode_offset[q_req]
|
||||
|
||||
# Apply mask modification only for valid indices
|
||||
return torch.where(
|
||||
is_valid,
|
||||
self.logical_mask_mod(b, h, logical_q_idx, logical_kv_idx),
|
||||
False,
|
||||
)
|
||||
|
||||
return final_mask_mod
|
||||
|
||||
def build_block_mask(self) -> BlockMask:
|
||||
assert self.mask_mod is not None
|
||||
return create_block_mask_compiled(
|
||||
self.mask_mod,
|
||||
None,
|
||||
None,
|
||||
self.num_actual_tokens,
|
||||
self.total_cache_tokens,
|
||||
device=self.block_table.device,
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
assert self.use_cascade is False, "Not implemented yet."
|
||||
assert self.common_prefix_len == 0, "Not implemented yet."
|
||||
assert self.cu_prefix_query_lens is None, "Not implemented yet."
|
||||
assert self.prefix_kv_lens is None, "Not implemented yet."
|
||||
assert self.suffix_kv_lens is None, "Not implemented yet."
|
||||
self.num_blocks = self.total_cache_tokens // self.block_size
|
||||
self.mask_mod = self.get_mask_mod()
|
||||
self.block_mask = self.build_block_mask()
|
||||
|
||||
|
||||
class FlexAttentionMetadataBuilder(
|
||||
AttentionMetadataBuilder[FlexAttentionMetadata]):
|
||||
|
||||
def __init__(self, runner: "GPUModelRunner", kv_cache_spec: AttentionSpec,
|
||||
block_table: BlockTable):
|
||||
model_config = runner.model_config
|
||||
|
||||
self.runner = runner
|
||||
self.num_heads_q = model_config.get_num_attention_heads(
|
||||
runner.parallel_config)
|
||||
self.num_heads_kv = model_config.get_num_kv_heads(
|
||||
runner.parallel_config)
|
||||
self.headdim = model_config.get_head_size()
|
||||
self.block_size = kv_cache_spec.block_size
|
||||
self.kv_cache_spec = kv_cache_spec
|
||||
self.block_table = block_table
|
||||
|
||||
def build(self, common_prefix_len: int,
|
||||
common_attn_metadata: CommonAttentionMetadata):
|
||||
num_reqs = common_attn_metadata.num_reqs
|
||||
num_actual_tokens = common_attn_metadata.num_actual_tokens
|
||||
max_query_len = common_attn_metadata.max_query_len
|
||||
|
||||
max_seq_len = self.runner.seq_lens_np[:num_reqs].max()
|
||||
query_start_loc = common_attn_metadata.query_start_loc
|
||||
seq_lens = common_attn_metadata.seq_lens
|
||||
|
||||
block_table = self.block_table
|
||||
block_table_tensor = block_table.get_device_tensor()[:num_reqs]
|
||||
block_table.slot_mapping[:num_actual_tokens].copy_(
|
||||
block_table.slot_mapping_cpu[:num_actual_tokens],
|
||||
non_blocking=True)
|
||||
slot_mapping = block_table.slot_mapping[:num_actual_tokens]
|
||||
|
||||
use_cascade = common_prefix_len > 0
|
||||
cu_prefix_query_lens = None
|
||||
prefix_kv_lens = None
|
||||
suffix_kv_lens = None
|
||||
if use_cascade:
|
||||
raise NotImplementedError("Not yet my friend")
|
||||
|
||||
block_size = self.kv_cache_spec.block_size
|
||||
max_possible_seq_len = self.runner.model_config.max_model_len
|
||||
total_cache_tokens = (self.runner.cache_config.num_gpu_blocks *
|
||||
block_size)
|
||||
|
||||
inverse_block_table = physical_to_logical_mapping(
|
||||
block_table_tensor, self.runner.cache_config.num_gpu_blocks)
|
||||
|
||||
# Per-request count of already-computed tokens; used as the decode
# offset when mapping packed query indices to logical positions.
|
||||
offset_tensor = torch.tensor(
|
||||
self.runner.input_batch.num_computed_tokens_cpu[:num_reqs]).to(
|
||||
self.runner.device, non_blocking=True)
|
||||
|
||||
out = FlexAttentionMetadata(
|
||||
num_actual_tokens=num_actual_tokens,
|
||||
max_query_len=max_query_len,
|
||||
query_start_loc=query_start_loc,
|
||||
max_seq_len=max_seq_len,
|
||||
seq_lens=seq_lens,
|
||||
block_table=block_table_tensor,
|
||||
slot_mapping=slot_mapping,
|
||||
use_cascade=use_cascade,
|
||||
common_prefix_len=common_prefix_len,
|
||||
cu_prefix_query_lens=cu_prefix_query_lens,
|
||||
prefix_kv_lens=prefix_kv_lens,
|
||||
suffix_kv_lens=suffix_kv_lens,
|
||||
block_size=block_size,
|
||||
max_possible_sequence_length=max_possible_seq_len,
|
||||
num_reqs=num_reqs,
|
||||
physical_to_logical=inverse_block_table,
|
||||
total_cache_tokens=total_cache_tokens,
|
||||
decode_offset=offset_tensor,
|
||||
)
|
||||
return out
|
||||
|
||||
|
||||
class FlexAttentionImpl(AttentionImpl):
|
||||
sliding_window: Optional[tuple[int, int]]
|
||||
alibi_slopes: Optional[torch.Tensor]
|
||||
logits_soft_cap: Optional[float]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
scale: float,
|
||||
num_kv_heads: int,
|
||||
alibi_slopes: Optional[list[float]],
|
||||
sliding_window: Optional[int],
|
||||
kv_cache_dtype: str,
|
||||
blocksparse_params: Optional[dict[str, Any]] = None,
|
||||
logits_soft_cap: Optional[float] = None,
|
||||
attn_type: AttentionType = AttentionType.DECODER,
|
||||
kv_sharing_target_layer_name: Optional[str] = None,
|
||||
) -> None:
|
||||
if blocksparse_params is not None:
|
||||
# TODO we should support this :think
|
||||
raise ValueError(
|
||||
"FlashAttention does not support block-sparse attention.")
|
||||
self.num_heads = num_heads
|
||||
self.head_size = head_size
|
||||
self.scale = float(scale)
|
||||
self.num_kv_heads = num_kv_heads
|
||||
|
||||
if alibi_slopes is not None:
|
||||
raise NotImplementedError(
|
||||
"FlexAttention does not support alibi slopes yet.")
|
||||
else:
|
||||
self.alibi_slopes = None
|
||||
if sliding_window is not None:
|
||||
raise NotImplementedError(
|
||||
"FlexAttention does not support sliding window yet.")
|
||||
else:
|
||||
self.sliding_window = (-1, -1)
|
||||
self.kv_cache_dtype = kv_cache_dtype
|
||||
self.logits_soft_cap = logits_soft_cap
|
||||
if self.logits_soft_cap is not None:
|
||||
raise NotImplementedError(
|
||||
"FlexAttention does not support logits soft cap yet.")
|
||||
|
||||
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
||||
|
||||
if kv_sharing_target_layer_name is not None:
|
||||
raise NotImplementedError(
|
||||
"FlexAttention does not support kv sharing yet.")
|
||||
|
||||
FlexAttentionBackend.validate_head_size(head_size)
|
||||
|
||||
if is_quantized_kv_cache(self.kv_cache_dtype):
|
||||
raise NotImplementedError(
|
||||
"FlexAttention does not support quantized kv-cache. Yet")
|
||||
|
||||
@staticmethod
|
||||
def view_as_4d(tensor: torch.Tensor) -> torch.Tensor:
|
||||
"""View a 3d tensor as 4D."""
|
||||
if tensor.ndim == 4:
|
||||
return tensor
|
||||
assert tensor.ndim == 3
|
||||
return tensor[None, :, :, :]
|
||||
|
||||
def forward(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: FlexAttentionMetadata,
|
||||
output: Optional[torch.Tensor] = None,
|
||||
output_scale: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
"""Forward pass with FLexAttention.
|
||||
|
||||
Args:
|
||||
query: shape = [num_tokens, num_heads, head_size]
|
||||
key: shape = [num_tokens, num_kv_heads, head_size]
|
||||
value: shape = [num_tokens, num_kv_heads, head_size]
|
||||
kv_cache = [2, num_blocks, block_size, num_kv_heads, head_size]
|
||||
attn_metadata: Metadata for attention.
|
||||
Returns:
|
||||
shape = [num_tokens, num_heads * head_size]
|
||||
"""
|
||||
assert output is not None, "Output tensor must be provided."
|
||||
if output_scale is not None:
|
||||
raise NotImplementedError(
|
||||
"fused output quantization is not yet supported"
|
||||
" for FlexAttentionImpl")
|
||||
|
||||
enable_gqa = self.num_kv_heads != self.num_heads
|
||||
|
||||
if attn_metadata is None:
|
||||
# Profiling run.
|
||||
return output
|
||||
# query = self.view_as_4d(query).permute(0, 2, 1, 3)
|
||||
# return torch.empty_like(query)
|
||||
|
||||
num_actual_tokens = attn_metadata.num_actual_tokens
|
||||
|
||||
key_cache, value_cache = kv_cache.unbind(0)
|
||||
|
||||
torch.ops._C_cache_ops.reshape_and_cache_flash(
|
||||
key,
|
||||
value,
|
||||
key_cache,
|
||||
value_cache,
|
||||
attn_metadata.slot_mapping,
|
||||
self.kv_cache_dtype,
|
||||
layer._k_scale,
|
||||
layer._v_scale,
|
||||
)
|
||||
|
||||
# View out the block_size dim
|
||||
key_cache = key_cache.view(-1, self.num_kv_heads, self.head_size)
|
||||
value_cache = value_cache.view(-1, self.num_kv_heads, self.head_size)
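# The paged cache [num_blocks, block_size, num_kv_heads, head_size] is now
# viewed as one flat KV sequence of num_blocks * block_size tokens; the
# block mask restricts each query to the live blocks of its own request.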
|
||||
query, key_cache, value_cache = map(
|
||||
lambda x: self.view_as_4d(x).permute(0, 2, 1, 3),
|
||||
(query, key_cache, value_cache),
|
||||
)
|
||||
query = query[:, :, :num_actual_tokens, :]
|
||||
# Doesn't work for now -> constraint violation
|
||||
# torch._dynamo.try_mark_dynamic(query, 2)
|
||||
|
||||
# default M=64, N=64 may run out of shared memory on some GPUs
|
||||
# TODO: Explicit configs for each GPU?
|
||||
# Not sure how to calculate the shared memory requirement
|
||||
extra_kernel_options = defaultdict[str, int](lambda: 64)
|
||||
if query.dtype == torch.float32:
|
||||
extra_kernel_options["BLOCK_M"] //= 2
|
||||
extra_kernel_options["BLOCK_N"] //= 2
|
||||
if current_platform.is_cuda():
|
||||
device_props = torch.cuda.get_device_properties()
|
||||
max_shared_memory = device_props.shared_memory_per_block_optin
|
||||
if max_shared_memory < 144 * 1024:
|
||||
extra_kernel_options["BLOCK_M"] //= 2
|
||||
extra_kernel_options["BLOCK_N"] //= 2
|
||||
|
||||
out = flex_attention_compiled(
|
||||
query,
|
||||
key_cache,
|
||||
value_cache,
|
||||
attn_metadata.score_mod,
|
||||
attn_metadata.block_mask,
|
||||
self.scale,
|
||||
enable_gqa=enable_gqa,
|
||||
kernel_options={
|
||||
"FORCE_USE_FLEX_ATTENTION": True,
|
||||
**extra_kernel_options
|
||||
},
|
||||
)
|
||||
|
||||
# Flex doesn't have an out variant today, rely on epilogue fusion
|
||||
out = out.permute(0, 2, 1, 3).squeeze(0)
|
||||
output[:num_actual_tokens, :, :].copy_(out)
|
||||
return output
|
||||
192
vllm/v1/attention/backends/mamba_attn.py
Normal file
@@ -0,0 +1,192 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.attention.backends.abstract import AttentionBackend
|
||||
from vllm.config import VllmConfig, get_layers_from_vllm_config
|
||||
from vllm.model_executor.layers.mamba.mamba2_metadata import (
|
||||
_query_start_loc_to_chunk_indices_offsets)
|
||||
from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
|
||||
CommonAttentionMetadata)
|
||||
from vllm.v1.kv_cache_interface import MambaSpec
|
||||
from vllm.v1.worker.block_table import BlockTable
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.worker.gpu_input_batch import InputBatch
|
||||
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
|
||||
|
||||
|
||||
def get_mamba2_chunk_size(vllm_config: VllmConfig) -> int:
|
||||
from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
|
||||
layers = get_layers_from_vllm_config(vllm_config, MambaMixer2)
|
||||
chunk_sizes = set(layer.chunk_size for layer in layers.values())
|
||||
assert len(
|
||||
chunk_sizes) == 1, "All Mamba2 layers must have the same chunk size"
|
||||
return chunk_sizes.pop()
|
||||
|
||||
|
||||
class Mamba2AttentionBackend(AttentionBackend):
|
||||
|
||||
@staticmethod
|
||||
def get_builder_cls() -> type["Mamba2AttentionMetadataBuilder"]:
|
||||
return Mamba2AttentionMetadataBuilder
|
||||
|
||||
|
||||
@dataclass
|
||||
class Mamba2AttentionMetadata:
|
||||
num_prefills: int
|
||||
num_prefill_tokens: int
|
||||
num_decodes: int
|
||||
num_decode_tokens: int
|
||||
query_start_loc: torch.Tensor
|
||||
seq_lens: torch.Tensor
|
||||
|
||||
has_initial_states: torch.Tensor
|
||||
prep_initial_states: bool
|
||||
chunk_size: int
|
||||
seq_idx: torch.Tensor
|
||||
chunk_indices: torch.Tensor
|
||||
chunk_offsets: torch.Tensor
|
||||
|
||||
state_indices_tensor: torch.Tensor # shape: [batch,]
|
||||
|
||||
|
||||
class Mamba2AttentionMetadataBuilder(
|
||||
AttentionMetadataBuilder[Mamba2AttentionMetadata]):
|
||||
|
||||
def __init__(self, runner: "GPUModelRunner", kv_cache_spec: MambaSpec,
|
||||
block_table: BlockTable):
|
||||
self.runner = runner
|
||||
self.kv_cache_spec = kv_cache_spec
|
||||
self.block_table = block_table
|
||||
self.chunk_size = get_mamba2_chunk_size(runner.vllm_config)
|
||||
|
||||
def reorder_batch(self, input_batch: "InputBatch",
|
||||
scheduler_output: "SchedulerOutput") -> bool:
|
||||
# NOTE (Chen): Copied from MLACommonMetadataBuilder and
|
||||
# FlashInferMetadataBuilder. Should be refactored later to avoid code
|
||||
# duplication of these 3 functions.
|
||||
# We now want to reorder the batch so that the "decode" requests are at
# the front and the "prefill" requests are at the back, using the least
# amount of swaps possible. (NOTE for now we loosely use "decode" to mean
# requests where attention is likely memory-bound and "prefill" to mean
# requests where attention is likely compute-bound, TODO(lucas): figure
# out a better naming here)
|
||||
decodes = []
|
||||
prefills = []
|
||||
num_decode_tokens = 0
|
||||
num_prefill_tokens = 0
|
||||
|
||||
for i, req_id in enumerate(input_batch.req_ids):
|
||||
num_tokens = scheduler_output.num_scheduled_tokens[req_id]
|
||||
# for now treat 1 scheduled token as "decode" even if it's not,
# we should update this to something like < 8 in the future but
# currently the decode run only supports num_tokens = 1
|
||||
if num_tokens == 1:
|
||||
decodes.append(i)
|
||||
num_decode_tokens += num_tokens
|
||||
else:
|
||||
prefills.append(i)
|
||||
num_prefill_tokens += num_tokens
|
||||
|
||||
# We hope that this is fairly minimal since decodes
# should be around for a number of iterations so hopefully they are
# relatively stationary (and new requests are generally appended to
# the persistent batch so they should already be at the back).
# To achieve this we loop over the decodes in descending order and
# the prefills in ascending order. We swap decodes from the "back",
# i.e. past where the last decode should be in the reordered batch,
# with prefills from the front of the batch.
# `decodes` and `prefills` are already in ascending order just based
# on the above loop.
|
||||
num_decodes = len(decodes)
|
||||
num_prefills = len(prefills)
|
||||
modified_batch = False
|
||||
|
||||
for i in range(1, min(num_decodes, num_prefills) + 1):
|
||||
# If the i-th decode from the back is still past where the last decode
# should sit after reordering, swap it with the prefill closest to the
# front of the batch; once a decode is already in place we can stop.
|
||||
decode_idx = decodes[num_decodes - i]
|
||||
if decode_idx < num_decodes:
|
||||
break
|
||||
|
||||
input_batch.swap_states(prefills[i - 1], decode_idx)
|
||||
modified_batch = True
|
||||
|
||||
# Save for next `build` call
|
||||
# TODO(lucas): this is a bit of a hack, we should probably have a
|
||||
# better way of doing this
|
||||
self._num_decodes = num_decodes
|
||||
self._num_prefills = num_prefills
|
||||
self._num_decode_tokens = num_decode_tokens
|
||||
self._num_prefill_tokens = num_prefill_tokens
|
||||
|
||||
return modified_batch
|
||||
|
||||
def build(self, common_prefix_len: int,
|
||||
common_attn_metadata: CommonAttentionMetadata):
|
||||
num_reqs = common_attn_metadata.num_reqs
|
||||
query_start_loc = common_attn_metadata.query_start_loc
|
||||
seq_lens = common_attn_metadata.seq_lens
|
||||
|
||||
seq_idx = None
|
||||
chunk_indices, chunk_offsets = None, None
|
||||
# Need flags to indicate if there are initial states
|
||||
# currently we really only support the FlashAttention backend
|
||||
has_initial_states = None
|
||||
prep_initial_states = False
|
||||
|
||||
state_indices_tensor = self.block_table.block_table[:num_reqs, 0]
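# Each request's recurrent Mamba state lives in a single slot, so the
# first block id in the block table is used as the state slot index.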
|
||||
|
||||
# Compute seq_idx, chunk_indices and chunk_offsets for prefill only
|
||||
if self._num_prefills > 0:
|
||||
#[batch,]
|
||||
has_initial_states_cpu = (
|
||||
self.runner.input_batch.
|
||||
num_computed_tokens_cpu_tensor[num_reqs -
|
||||
self._num_prefills:num_reqs]
|
||||
> 0)
|
||||
prep_initial_states = torch.any(has_initial_states_cpu).item()
|
||||
has_initial_states = has_initial_states_cpu.to(
|
||||
query_start_loc.device)
|
||||
|
||||
query_start_loc_p = common_attn_metadata.query_start_loc[
|
||||
-self._num_prefills - 1:] - self._num_decode_tokens
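# Cumulative query lengths for the prefill requests only, rebased so the
# first prefill token sits at position 0 (decode tokens occupy the first
# num_decode_tokens positions of the batch).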
|
||||
|
||||
seq_idx = torch.repeat_interleave(
|
||||
torch.arange(self._num_prefills,
|
||||
dtype=torch.int32,
|
||||
device=query_start_loc_p.device),
|
||||
query_start_loc_p.diff(),
|
||||
output_size=self._num_prefill_tokens)
|
||||
seq_idx.unsqueeze_(0)
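# e.g. prefill query lens [3, 2] -> seq_idx [[0, 0, 0, 1, 1]]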
|
||||
|
||||
# We compute metadata for chunked prefill once at the top level
|
||||
# model forward and reuse them in mamba layers. If not needed,
|
||||
# they will be ignored inside mamba kernels.
|
||||
if prep_initial_states:
|
||||
chunk_indices, chunk_offsets = (
|
||||
_query_start_loc_to_chunk_indices_offsets(
|
||||
query_start_loc_p, self.chunk_size,
|
||||
self._num_prefill_tokens))
|
||||
|
||||
attn_metadata = Mamba2AttentionMetadata(
|
||||
num_prefills=self._num_prefills,
|
||||
num_prefill_tokens=self._num_prefill_tokens,
|
||||
num_decodes=self._num_decodes,
|
||||
num_decode_tokens=self._num_decode_tokens,
|
||||
query_start_loc=query_start_loc,
|
||||
seq_lens=seq_lens,
|
||||
has_initial_states=has_initial_states,
|
||||
prep_initial_states=prep_initial_states,
|
||||
chunk_size=self.chunk_size,
|
||||
seq_idx=seq_idx,
|
||||
chunk_indices=chunk_indices,
|
||||
chunk_offsets=chunk_offsets,
|
||||
state_indices_tensor=state_indices_tensor,
|
||||
)
|
||||
return attn_metadata
|
||||
0
vllm/v1/attention/backends/mla/__init__.py
Normal file
1114
vllm/v1/attention/backends/mla/common.py
Normal file
File diff suppressed because it is too large
250
vllm/v1/attention/backends/mla/concatv3Tritonfinalv2.py
Normal file
@@ -0,0 +1,250 @@
|
||||
import triton
|
||||
import triton.language as tl
|
||||
import torch
|
||||
from functools import reduce
|
||||
import pytest
import math
|
||||
|
||||
@pytest.mark.parametrize("shape_pair,dim", [
|
||||
|
||||
(((4, 8, 512), (4, 8, 64)), 2),
|
||||
(((8, 8, 512), (8, 8, 64)), 2),
|
||||
(((16, 8, 512), (16, 8, 64)), 2),
|
||||
(((32, 8, 512), (32, 8, 64)), 2),
|
||||
(((64, 8, 512), (64, 8, 64)), 2),
|
||||
(((128, 8, 512), (128, 8, 64)), 2),
|
||||
(((256, 8, 512), (256, 8, 64)), 2),
|
||||
(((512, 8, 512), (512, 8, 64)), 2),
|
||||
(((672, 8, 512), (672, 8, 64)), 2),
|
||||
(((768, 8, 512), (768, 8, 64)), 2),
|
||||
(((896, 8, 512), (896, 8, 64)), 2),
|
||||
(((1024, 8, 512), (1024, 8, 64)), 2),
|
||||
|
||||
(((4, 16, 512), (4, 16, 64)), 2),
|
||||
(((8, 16, 512), (8, 16, 64)), 2),
|
||||
(((16, 16, 512), (16, 16, 64)), 2),
|
||||
(((32, 16, 512), (32, 16, 64)), 2),
|
||||
(((64, 16, 512), (64, 16, 64)), 2),
|
||||
(((128, 16, 512), (128, 16, 64)), 2),
|
||||
(((256, 16, 512), (256, 16, 64)), 2),
|
||||
(((512, 16, 512), (512, 16, 64)), 2),
|
||||
(((672, 16, 512), (672, 16, 64)), 2),
|
||||
(((768, 16, 512), (768, 16, 64)), 2),
|
||||
(((896, 16, 512), (896, 16, 64)), 2),
|
||||
(((1024, 16, 512), (1024, 16, 64)), 2),
|
||||
|
||||
|
||||
(((4, 32, 512), (4, 32, 64)), 2),
|
||||
(((8, 32, 512), (8, 32, 64)), 2),
|
||||
(((16, 32, 512), (16, 32, 64)), 2),
|
||||
(((32, 32, 512), (32, 32, 64)), 2),
|
||||
(((64, 32, 512), (64, 32, 64)), 2),
|
||||
(((128, 32, 512), (128, 32, 64)), 2),
|
||||
(((256, 32, 512), (256, 32, 64)), 2),
|
||||
(((512, 32, 512), (512, 32, 64)), 2),
|
||||
(((672, 32, 512), (672, 32, 64)), 2),
|
||||
(((768, 32, 512), (768, 32, 64)), 2),
|
||||
(((896, 32, 512), (896, 32, 64)), 2),
|
||||
(((1024, 32, 512), (1024, 32, 64)), 2),
|
||||
|
||||
|
||||
|
||||
(((4, 32, 128), (4, 32, 64)), 2),
|
||||
(((8, 32, 128), (8, 32, 64)), 2),
|
||||
(((16, 32, 128), (16, 32, 64)), 2),
|
||||
(((32, 32, 128), (32, 32, 64)), 2),
|
||||
(((64, 32, 128), (64, 32, 64)), 2),
|
||||
(((128, 32, 128), (128, 32, 64)), 2),
|
||||
(((256, 32, 128), (256, 32, 64)), 2),
|
||||
(((512, 32, 128), (512, 32, 64)), 2),
|
||||
(((672, 32, 128), (672, 32, 64)), 2),
|
||||
(((768, 32, 128), (768, 32, 64)), 2),
|
||||
(((896, 32, 128), (896, 32, 64)), 2),
|
||||
(((1024, 32, 128), (1024, 32, 64)), 2),
|
||||
|
||||
])
|
||||
def test_concat_Acc(shape_pair, dim):
|
||||
|
||||
torch.manual_seed(1)
|
||||
shape1, shape2 = shape_pair
|
||||
x = torch.randn(*shape1, device='cuda', dtype=torch.bfloat16)
|
||||
y = torch.randn(*shape2, device='cuda', dtype=torch.bfloat16)
|
||||
expected = torch.cat([x,y], dim=dim)
|
||||
result = concat_helper(x, y, dim=dim)
|
||||
assert torch.allclose(result, expected, rtol=1e-5, atol=1e-5), "Mismatch"
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
@triton.jit
|
||||
def concat_kernel_prefill(
|
||||
A_ptr, B_ptr, C_ptr,
|
||||
A_section_numel, B_section_numel, C_section_numel,
|
||||
Per_block,
|
||||
section_num,
|
||||
BLOCK_SIZE: tl.constexpr
|
||||
):
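# Fast path for the prefill shape (A last dim 128, B last dim 64): each
# iteration copies two adjacent sections at once, loading 256 contiguous
# elements of A (two sections) and 128 of B, and scattering them into the
# corresponding 192-element sections of C.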
|
||||
block_idx = tl.program_id(0)  # index of the current program instance (block)
|
||||
|
||||
for sub_section_index in range(Per_block//2):
|
||||
sub_section_offset = block_idx * Per_block + sub_section_index * 2
|
||||
if sub_section_offset <= section_num-1:
|
||||
|
||||
C_section_start = C_ptr + sub_section_offset * C_section_numel
|
||||
A_section_start = A_ptr + sub_section_offset * A_section_numel
|
||||
B_section_start = B_ptr + sub_section_offset * B_section_numel
|
||||
|
||||
Arrange_doubleA = tl.arange(0, 256)
|
||||
mask = Arrange_doubleA < (256)
|
||||
Arrange2 = (tl.arange(0, 128)[None,:] + tl.arange(0, 2)[:,None]).reshape(256)
|
||||
val_from_A = tl.load(A_section_start + Arrange_doubleA)
|
||||
tensorAsn = tl.full((256,), 0, tl.int32)
|
||||
tensorAsn2 = tl.full((256,), (C_section_numel-1), tl.int32)
|
||||
tensor_offsets = tl.where(Arrange_doubleA < A_section_numel,tensorAsn , tensorAsn2)
|
||||
off = Arrange2 + tensor_offsets
|
||||
tl.store(C_section_start + off,val_from_A,mask=mask)
|
||||
|
||||
Arrange_doubleB = tl.arange(0, 128)
|
||||
mask = Arrange_doubleB < (B_section_numel*2)
|
||||
val_from_B = tl.load(B_section_start + Arrange_doubleB,mask=mask)
|
||||
|
||||
Arrange3 = (tl.arange(0, 64)[None,:] + tl.arange(0, 2)[:,None]).reshape(128)
|
||||
tensorAsn = tl.full((128,), A_section_numel, tl.int32)
|
||||
tensorAsn2 = tl.full((128,), (C_section_numel + A_section_numel-1), tl.int32)
|
||||
tensor_offsets = tl.where(Arrange_doubleB < B_section_numel,tensorAsn , tensorAsn2)
|
||||
tl.store(C_section_start+ Arrange3 + tensor_offsets , val_from_B)
|
||||
|
||||
@triton.jit
|
||||
def concat_kernel(
|
||||
A_ptr, B_ptr, C_ptr,
|
||||
A_section_numel, B_section_numel, C_section_numel,
|
||||
Per_block,
|
||||
section_num,
|
||||
BLOCK_SIZE: tl.constexpr
|
||||
):
|
||||
block_idx = tl.program_id(0)
|
||||
for sub_section_index in range(Per_block):
|
||||
sub_offset = block_idx * Per_block + sub_section_index
|
||||
if sub_offset <= section_num-1:
|
||||
C_ptr_block_start = C_ptr + sub_offset * C_section_numel
|
||||
A_ptr_block_start = A_ptr + sub_offset * A_section_numel
|
||||
B_ptr_block_start = B_ptr + sub_offset * B_section_numel
|
||||
|
||||
for offset in range(0, A_section_numel, BLOCK_SIZE):
|
||||
offset_idx = offset + tl.arange(0, BLOCK_SIZE)
|
||||
mask = offset_idx < A_section_numel
|
||||
val_from_A = tl.load(A_ptr_block_start + offset_idx, mask=mask)
|
||||
tl.store(C_ptr_block_start + offset_idx, val_from_A, mask=mask)
|
||||
|
||||
for offset in range(0, B_section_numel, BLOCK_SIZE):
|
||||
offset_idx = offset + tl.arange(0, BLOCK_SIZE)
|
||||
mask = offset_idx < B_section_numel
|
||||
val_from_B = tl.load(B_ptr_block_start + offset_idx, mask=mask)
|
||||
tl.store(C_ptr_block_start + A_section_numel + offset_idx, val_from_B, mask=mask)
|
||||
|
||||
def concat_helper(A:torch.Tensor, B:torch.Tensor, dim:int):
|
||||
A = A.contiguous()
|
||||
B = B.contiguous()
|
||||
output_shape = list(A.shape)
|
||||
output_shape[dim] = A.shape[dim] + B.shape[dim]
|
||||
C = torch.empty(output_shape, device=A.device, dtype=A.dtype)
|
||||
|
||||
if dim!=0 :
|
||||
block_num = reduce(lambda x, y: x * y, output_shape[:dim])
|
||||
Per_block = 1
|
||||
unit_offset_A, unit_offset_B, unit_offset_C = A.stride(dim-1),B.stride(dim-1),C.stride(dim-1)
|
||||
#case prefill
|
||||
if (A.shape[2] == 128 and B.shape[2] == 64 and A.shape[0] > 16):
|
||||
Per_block = 8
|
||||
num_blocks = math.ceil(block_num/Per_block)
|
||||
concat_kernel_prefill[(num_blocks,)](
|
||||
A, B, C,
|
||||
unit_offset_A, unit_offset_B, unit_offset_C,
|
||||
Per_block,
|
||||
block_num,
|
||||
BLOCK_SIZE=1024)
|
||||
return C
|
||||
|
||||
else:
|
||||
if (A.shape[1]==8 and A.shape[0] > 128) or ( A.shape[1]==16 and A.shape[0] > 96) or ( A.shape[1]==32 and A.shape[2] == 512 and A.shape[0] > 64):
|
||||
Per_block = 2
|
||||
num_blocks = math.ceil(block_num/Per_block)
|
||||
concat_kernel[(num_blocks,)](
|
||||
A, B, C,
|
||||
unit_offset_A, unit_offset_B, unit_offset_C,
|
||||
Per_block,
|
||||
block_num,
|
||||
BLOCK_SIZE=1024)
|
||||
return C
|
||||
assert False, "not supported"
|
||||
|
||||
|
||||
configs = []
|
||||
configs.append(
|
||||
triton.testing.Benchmark(
|
||||
x_names=['size'],
|
||||
x_vals=[4,8,16,32,64,96,128,256,512,768,1024],
|
||||
x_log=True,
|
||||
line_arg='provider',
|
||||
line_vals=['triton', 'torch'],
|
||||
line_names=['Triton', 'Torch'],
|
||||
styles=[('blue', '-'), ('green', '-')],
|
||||
ylabel='s',
|
||||
plot_name='concat-dim2',
|
||||
args={"dim":2},
|
||||
),
|
||||
)
|
||||
|
||||
@triton.testing.perf_report(configs)
|
||||
def benchmark(size, provider, dim):
|
||||
x = torch.rand([size,8,512], device='cuda', dtype=torch.bfloat16)
|
||||
y = torch.rand([size,8,64], device='cuda', dtype=torch.bfloat16)
|
||||
quantiles = [0.5, 0.2, 0.8]
|
||||
if provider == 'torch':
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.cat([x,y],dim=dim), quantiles=quantiles)
|
||||
if provider == 'triton':
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(lambda: concat_helper(x, y,dim=dim), quantiles=quantiles)
|
||||
return (ms*1000), (max_ms*1000), (min_ms*1000)
|
||||
|
||||
@triton.testing.perf_report(configs)
|
||||
def benchmark_16(size, provider, dim):
|
||||
x = torch.rand([size,16,512], device='cuda', dtype=torch.bfloat16)
|
||||
y = torch.rand([size,16,64], device='cuda', dtype=torch.bfloat16)
|
||||
quantiles = [0.5, 0.2, 0.8]
|
||||
if provider == 'torch':
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.cat([x,y],dim=dim), quantiles=quantiles)
|
||||
if provider == 'triton':
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(lambda: concat_helper(x, y,dim=dim), quantiles=quantiles)
|
||||
return (ms*1000), (max_ms*1000), (min_ms*1000)
|
||||
|
||||
@triton.testing.perf_report(configs)
|
||||
def benchmark_32(size, provider, dim):
|
||||
x = torch.rand([size,32,512], device='cuda', dtype=torch.bfloat16)
|
||||
y = torch.rand([size,32,64], device='cuda', dtype=torch.bfloat16)
|
||||
quantiles = [0.5, 0.2, 0.8]
|
||||
if provider == 'torch':
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.cat([x,y],dim=dim), quantiles=quantiles)
|
||||
if provider == 'triton':
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(lambda: concat_helper(x, y,dim=dim), quantiles=quantiles)
|
||||
return (ms*1000), (max_ms*1000), (min_ms*1000)
|
||||
|
||||
@triton.testing.perf_report(configs)
|
||||
def benchmark_prefill(size, provider, dim):
|
||||
x = torch.rand([size,32,128], device='cuda', dtype=torch.bfloat16)
|
||||
y = torch.rand([size,32,64], device='cuda', dtype=torch.bfloat16)
|
||||
quantiles = [0.5, 0.2, 0.8]
|
||||
if provider == 'torch':
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.cat([x,y],dim=dim), quantiles=quantiles)
|
||||
if provider == 'triton':
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(lambda: concat_helper(x, y,dim=dim), quantiles=quantiles)
|
||||
return (ms*1000), (max_ms*1000), (min_ms*1000)
|
||||
|
||||
if __name__ == '__main__':
|
||||
# benchmark.run(save_path="./triton_test_8",print_data=True)
|
||||
# benchmark_16.run(save_path="./triton_test_16",print_data=True)
|
||||
# benchmark_32.run(save_path="./triton_test_32",print_data=True)
|
||||
benchmark_prefill.run(save_path="./triton_test_prefill",print_data=True)
|
||||
|
||||
|
||||
248
vllm/v1/attention/backends/mla/concatv4_decode_only.py
Normal file
@@ -0,0 +1,248 @@
|
||||
import triton
|
||||
import triton.language as tl
|
||||
import torch
|
||||
from functools import reduce
|
||||
import pytest
import math
|
||||
|
||||
@pytest.mark.parametrize("shape_pair,dim", [
|
||||
|
||||
(((4, 8, 512), (4, 8, 64)), 2),
|
||||
(((8, 8, 512), (8, 8, 64)), 2),
|
||||
(((16, 8, 512), (16, 8, 64)), 2),
|
||||
(((32, 8, 512), (32, 8, 64)), 2),
|
||||
(((64, 8, 512), (64, 8, 64)), 2),
|
||||
(((128, 8, 512), (128, 8, 64)), 2),
|
||||
(((256, 8, 512), (256, 8, 64)), 2),
|
||||
(((512, 8, 512), (512, 8, 64)), 2),
|
||||
(((672, 8, 512), (672, 8, 64)), 2),
|
||||
(((768, 8, 512), (768, 8, 64)), 2),
|
||||
(((896, 8, 512), (896, 8, 64)), 2),
|
||||
(((1024, 8, 512), (1024, 8, 64)), 2),
|
||||
|
||||
(((4, 16, 512), (4, 16, 64)), 2),
|
||||
(((8, 16, 512), (8, 16, 64)), 2),
|
||||
(((16, 16, 512), (16, 16, 64)), 2),
|
||||
(((32, 16, 512), (32, 16, 64)), 2),
|
||||
(((64, 16, 512), (64, 16, 64)), 2),
|
||||
(((128, 16, 512), (128, 16, 64)), 2),
|
||||
(((256, 16, 512), (256, 16, 64)), 2),
|
||||
(((512, 16, 512), (512, 16, 64)), 2),
|
||||
(((672, 16, 512), (672, 16, 64)), 2),
|
||||
(((768, 16, 512), (768, 16, 64)), 2),
|
||||
(((896, 16, 512), (896, 16, 64)), 2),
|
||||
(((1024, 16, 512), (1024, 16, 64)), 2),
|
||||
|
||||
|
||||
(((4, 32, 512), (4, 32, 64)), 2),
|
||||
(((8, 32, 512), (8, 32, 64)), 2),
|
||||
(((16, 32, 512), (16, 32, 64)), 2),
|
||||
(((32, 32, 512), (32, 32, 64)), 2),
|
||||
(((64, 32, 512), (64, 32, 64)), 2),
|
||||
(((128, 32, 512), (128, 32, 64)), 2),
|
||||
(((256, 32, 512), (256, 32, 64)), 2),
|
||||
(((512, 32, 512), (512, 32, 64)), 2),
|
||||
(((672, 32, 512), (672, 32, 64)), 2),
|
||||
(((768, 32, 512), (768, 32, 64)), 2),
|
||||
(((896, 32, 512), (896, 32, 64)), 2),
|
||||
(((1024, 32, 512), (1024, 32, 64)), 2),
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
])
|
||||
def test_concat_Acc(shape_pair, dim):
|
||||
|
||||
torch.manual_seed(1)
|
||||
shape1, shape2 = shape_pair
|
||||
M = shape1[0]
|
||||
N = shape1[1]
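# Unlike a plain contiguous test, build x and y as non-contiguous
# as_strided views so the kernel is exercised on strided layouts
# (the strides below are assumed to mimic what the decode path produces).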
|
||||
x_sizes = [M, N, 512]
|
||||
x_strides = [512, 512*M, 1]
|
||||
x_max_index = M * N * 512
|
||||
x_required_length = x_max_index
|
||||
x_data = torch.arange(x_required_length,device='cuda').bfloat16()
|
||||
x = torch.as_strided(x_data, size=x_sizes, stride=x_strides)
|
||||
|
||||
# print("形状:", x.shape) # [4, 8, 512]
|
||||
# print("步幅:", x.stride()) # (1536, 192, 1)
|
||||
|
||||
y_sizes = [M, N, 64]
|
||||
y_strides = [1536*(N//8), 192, 1]
|
||||
y_max_index = 1536*(N//8) * M
|
||||
y_required_length = y_max_index
|
||||
y_data = torch.arange(y_required_length,device='cuda').bfloat16()
|
||||
y = torch.as_strided(y_data, size=y_sizes, stride=y_strides)
|
||||
|
||||
|
||||
|
||||
expected = torch.cat([x,y], dim=dim)
|
||||
result = concat_helper(x, y, dim=dim)
|
||||
assert torch.allclose(result, expected, rtol=1e-5, atol=1e-5), "Mismatch"
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
@triton.jit
|
||||
def concat_kernel(
|
||||
A_ptr, B_ptr, C_ptr,
|
||||
A_section_numel, B_section_numel, C_section_numel,
|
||||
Per_block,
|
||||
section_num,
|
||||
M,
|
||||
N,
|
||||
Astride_0,
|
||||
Astride_1,
|
||||
Astride_2,
|
||||
Bstride_0,
|
||||
Bstride_1,
|
||||
Bstride_2,
|
||||
BLOCK_SIZE: tl.constexpr
|
||||
):
|
||||
block_idx = tl.program_id(0)
|
||||
for sub_section_index in range(Per_block):
|
||||
sub_offset = block_idx * Per_block + sub_section_index
|
||||
M_idx = sub_offset // N
|
||||
N_idx = sub_offset % N
|
||||
if sub_offset <= section_num-1:
|
||||
C_ptr_block_start = C_ptr + sub_offset * C_section_numel
|
||||
A_ptr_block_start = A_ptr + M_idx * Astride_0 + N_idx * Astride_1
|
||||
B_ptr_block_start = B_ptr + M_idx * Bstride_0 + N_idx * Bstride_1
|
||||
|
||||
for offset in range(0, A_section_numel, BLOCK_SIZE):
|
||||
offset_idx = offset + tl.arange(0, BLOCK_SIZE)
|
||||
mask = offset_idx < A_section_numel
|
||||
val_from_A = tl.load(A_ptr_block_start + offset_idx, mask=mask)
|
||||
tl.store(C_ptr_block_start + offset_idx, val_from_A, mask=mask)
|
||||
|
||||
for offset in range(0, B_section_numel, BLOCK_SIZE):
|
||||
offset_idx = offset + tl.arange(0, BLOCK_SIZE)
|
||||
mask = offset_idx < B_section_numel
|
||||
val_from_B = tl.load(B_ptr_block_start + offset_idx, mask=mask)
|
||||
tl.store(C_ptr_block_start + A_section_numel + offset_idx, val_from_B, mask=mask)
|
||||
|
||||
def concat_helper(A:torch.Tensor, B:torch.Tensor, dim:int):
|
||||
|
||||
output_shape = list(A.shape)
|
||||
output_shape[dim] = A.shape[dim] + B.shape[dim]
|
||||
C = torch.empty(output_shape, device=A.device, dtype=A.dtype)
|
||||
|
||||
if dim!=0 :
|
||||
block_num = reduce(lambda x, y: x * y, output_shape[:dim])
|
||||
Per_block = 1
|
||||
        unit_offset_A, unit_offset_B, unit_offset_C = \
            A.shape[dim], B.shape[dim], C.shape[dim]
|
||||
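        # Heuristic tuning: for larger decode batches let each program copy
        # several (M, N) sections so that fewer programs are launched.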
        if (A.shape[1] == 8 and A.shape[0] > 512) or \
                (A.shape[1] == 16 and A.shape[0] > 256):
|
||||
Per_block = 2
|
||||
        if A.shape[1] == 32 and A.shape[2] == 512 and A.shape[0] > 256:
|
||||
Per_block = 8
|
||||
        num_blocks = math.ceil(block_num / Per_block)
|
||||
concat_kernel[(num_blocks,)](
|
||||
A, B, C,
|
||||
unit_offset_A, unit_offset_B, unit_offset_C,
|
||||
Per_block,
|
||||
block_num,
|
||||
output_shape[0],
|
||||
output_shape[1],
|
||||
A.stride(0),
|
||||
A.stride(1),
|
||||
A.stride(2),
|
||||
B.stride(0),
|
||||
B.stride(1),
|
||||
B.stride(2),
|
||||
BLOCK_SIZE=1024)
|
||||
return C
|
||||
|
||||
assert False, "not support"
|
||||
|
||||
|
||||
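# Minimal usage sketch (illustrative values, not taken from the benchmarks
# above): concat_helper is intended as a drop-in replacement for torch.cat
# along the last dim of the 3-D decode tensors exercised in the tests, e.g.
#
#     q_nope = torch.randn(64, 16, 512, device='cuda', dtype=torch.bfloat16)
#     q_pe = torch.randn(64, 16, 64, device='cuda', dtype=torch.bfloat16)
#     q = concat_helper(q_nope, q_pe, dim=-1)
#     assert torch.equal(q, torch.cat([q_nope, q_pe], dim=-1))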
configs = []
|
||||
configs.append(
|
||||
triton.testing.Benchmark(
|
||||
x_names=['M','N'],
|
||||
x_vals=[(4,8),(8,8),(16,8),(32,8),(64,8),(96,8),(128,8),(256,8),(512,8),(768,8),(1024,8), \
|
||||
(4,16),(8,16),(16,16),(32,16),(64,16),(96,16),(128,16),(256,16),(512,16),(768,16),(1024,16), \
|
||||
(4,32),(8,32),(16,32),(32,32),(64,32),(96,32),(128,32),(256,32),(512,32),(768,32),(1024,32)],
|
||||
x_log=True,
|
||||
line_arg='provider',
|
||||
line_vals=['triton', 'torch'],
|
||||
line_names=['Triton', 'Torch'],
|
||||
styles=[('blue', '-'), ('green', '-')],
|
||||
        ylabel='us',
|
||||
plot_name='concat-dim2',
|
||||
args={"dim":2},
|
||||
),
|
||||
)
|
||||
|
||||
@triton.testing.perf_report(configs)
|
||||
def benchmark(M, N, provider, dim):
|
||||
|
||||
x_sizes = [M, N, 512]
|
||||
x_strides = [512, 512*M, 1]
|
||||
x_max_index = M * N * 512
|
||||
x_required_length = x_max_index
|
||||
x_data = torch.arange(x_required_length,device='cuda').bfloat16()
|
||||
x = torch.as_strided(x_data, size=x_sizes, stride=x_strides)
|
||||
|
||||
# print("形状:", x.shape) # [M, 8, 512]
|
||||
# print("步幅:", x.stride()) # (512, 512*M, 1)
|
||||
|
||||
y_sizes = [M, N, 64]
|
||||
y_strides = [1536*(N//8), 192, 1]
|
||||
y_max_index = 1536*(N//8) * M
|
||||
y_required_length = y_max_index
|
||||
y_data = torch.arange(y_required_length,device='cuda').bfloat16()
|
||||
y = torch.as_strided(y_data, size=y_sizes, stride=y_strides)
|
||||
|
||||
# print("形状:", y.shape) # [M, 8, 64]
|
||||
# print("步幅:", y.stride()) # (1536, 192, 1)
|
||||
|
||||
quantiles = [0.5, 0.2, 0.8]
|
||||
if provider == 'torch':
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.cat([x,y],dim=dim), quantiles=quantiles)
|
||||
if provider == 'triton':
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(lambda: concat_helper(x, y,dim=dim), quantiles=quantiles)
|
||||
return (ms*1000), (max_ms*1000), (min_ms*1000)
|
||||
|
||||
# @triton.testing.perf_report(configs)
|
||||
# def benchmark_16(size, provider, dim):
|
||||
# x = torch.rand([size,16,512], device='cuda', dtype=torch.bfloat16)
|
||||
# y = torch.rand([size,16,64], device='cuda', dtype=torch.bfloat16)
|
||||
# quantiles = [0.5, 0.2, 0.8]
|
||||
# if provider == 'torch':
|
||||
# ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.cat([x,y],dim=dim), quantiles=quantiles)
|
||||
# if provider == 'triton':
|
||||
# ms, min_ms, max_ms = triton.testing.do_bench(lambda: concat_helper(x, y,dim=dim), quantiles=quantiles)
|
||||
# return (ms*1000), (max_ms*1000), (min_ms*1000)
|
||||
|
||||
# @triton.testing.perf_report(configs)
|
||||
# def benchmark_32(size, provider, dim):
|
||||
# x = torch.rand([size,32,512], device='cuda', dtype=torch.bfloat16)
|
||||
# y = torch.rand([size,32,64], device='cuda', dtype=torch.bfloat16)
|
||||
# quantiles = [0.5, 0.2, 0.8]
|
||||
# if provider == 'torch':
|
||||
# ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.cat([x,y],dim=dim), quantiles=quantiles)
|
||||
# if provider == 'triton':
|
||||
# ms, min_ms, max_ms = triton.testing.do_bench(lambda: concat_helper(x, y,dim=dim), quantiles=quantiles)
|
||||
# return (ms*1000), (max_ms*1000), (min_ms*1000)
|
||||
|
||||
# @triton.testing.perf_report(configs)
|
||||
# def benchmark_prefill(size, provider, dim):
|
||||
# x = torch.rand([size,32,128], device='cuda', dtype=torch.bfloat16)
|
||||
# y = torch.rand([size,32,64], device='cuda', dtype=torch.bfloat16)
|
||||
# quantiles = [0.5, 0.2, 0.8]
|
||||
# if provider == 'torch':
|
||||
# ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.cat([x,y],dim=dim), quantiles=quantiles)
|
||||
# if provider == 'triton':
|
||||
# ms, min_ms, max_ms = triton.testing.do_bench(lambda: concat_helper(x, y,dim=dim), quantiles=quantiles)
|
||||
# return (ms*1000), (max_ms*1000), (min_ms*1000)
|
||||
|
||||
if __name__ == '__main__':
|
||||
benchmark.run(save_path="./triton_test",print_data=True)
|
||||
# benchmark_16.run(save_path="./triton_test_16",print_data=True)
|
||||
# benchmark_32.run(save_path="./triton_test_32",print_data=True)
|
||||
# benchmark_prefill.run(save_path="./triton_test_prefill",print_data=True)
|
||||
|
||||
|
||||
98
vllm/v1/attention/backends/mla/cutlass_mla.py
Normal file
98
vllm/v1/attention/backends/mla/cutlass_mla.py
Normal file
@@ -0,0 +1,98 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
|
||||
import vllm._custom_ops as ops
|
||||
from vllm.attention.backends.abstract import (AttentionType,
|
||||
is_quantized_kv_cache)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
|
||||
MLACommonImpl,
|
||||
MLACommonMetadata)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class CutlassMLABackend(MLACommonBackend):
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "CUTLASS_MLA_VLLM_V1"
|
||||
|
||||
@staticmethod
|
||||
def get_impl_cls() -> type["CutlassMLAImpl"]:
|
||||
return CutlassMLAImpl
|
||||
|
||||
|
||||
class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
scale: float,
|
||||
num_kv_heads: int,
|
||||
alibi_slopes: Optional[list[float]],
|
||||
sliding_window: Optional[int],
|
||||
kv_cache_dtype: str,
|
||||
blocksparse_params: Optional[dict[str, Any]],
|
||||
logits_soft_cap: Optional[float],
|
||||
attn_type: str,
|
||||
kv_sharing_target_layer_name: Optional[str],
|
||||
# MLA Specific Arguments
|
||||
**mla_args) -> None:
|
||||
super().__init__(num_heads, head_size, scale, num_kv_heads,
|
||||
alibi_slopes, sliding_window, kv_cache_dtype,
|
||||
blocksparse_params, logits_soft_cap, attn_type,
|
||||
kv_sharing_target_layer_name, **mla_args)
|
||||
|
||||
unsupported_features = [
|
||||
alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap
|
||||
]
|
||||
if any(unsupported_features):
|
||||
raise NotImplementedError(
|
||||
"CutlassMLAImpl does not support one of the following: "
|
||||
"alibi_slopes, sliding_window, blocksparse_params, "
|
||||
"logits_soft_cap")
|
||||
|
||||
if attn_type != AttentionType.DECODER:
|
||||
raise NotImplementedError("Encoder self-attention and "
|
||||
"encoder/decoder cross-attention "
|
||||
"are not implemented for "
|
||||
"CutlassMLAImpl")
|
||||
|
||||
if is_quantized_kv_cache(self.kv_cache_dtype):
|
||||
raise NotImplementedError(
|
||||
"CutlassMLA V1 with FP8 KV cache not yet supported")
|
||||
|
||||
def _forward_decode(
|
||||
self,
|
||||
q_nope: torch.Tensor,
|
||||
q_pe: torch.Tensor,
|
||||
kv_c_and_k_pe_cache: torch.Tensor,
|
||||
attn_metadata: MLACommonMetadata,
|
||||
) -> torch.Tensor:
|
||||
assert kv_c_and_k_pe_cache.numel() > 0
|
||||
assert attn_metadata.decode is not None
|
||||
|
||||
if self.kv_cache_dtype.startswith("fp8"):
|
||||
raise NotImplementedError("FP8 Cutlass MLA not yet supported")
|
||||
|
||||
B = q_nope.shape[0]
|
||||
|
||||
o = torch.empty((B, self.num_heads, self.kv_lora_rank),
|
||||
dtype=q_nope.dtype,
|
||||
device=q_nope.device)
|
||||
|
||||
# Run MLA
|
||||
# Clone q_nope and q_pe to make sure strides computation is correct.
|
||||
q_nope = q_nope.clone()
|
||||
q_pe = q_pe.clone()
|
||||
ops.cutlass_mla_decode(o, q_nope, q_pe, kv_c_and_k_pe_cache,
|
||||
attn_metadata.decode.seq_lens,
|
||||
attn_metadata.decode.block_table, self.scale)
|
||||
|
||||
return self._v_up_proj(o)
|
||||
195
vllm/v1/attention/backends/mla/flashmla.py
Normal file
195
vllm/v1/attention/backends/mla/flashmla.py
Normal file
@@ -0,0 +1,195 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, ClassVar, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.attention.backends.abstract import (AttentionType,
|
||||
is_quantized_kv_cache)
|
||||
from vllm.attention.ops.flashmla import (flash_mla_with_kvcache,
|
||||
get_mla_metadata,
|
||||
is_flashmla_supported)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
|
||||
MLACommonDecodeMetadata,
|
||||
MLACommonImpl,
|
||||
MLACommonMetadata,
|
||||
MLACommonMetadataBuilder)
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||
from vllm.v1.worker.block_table import BlockTable
|
||||
from vllm import envs
|
||||
from vllm.v1.attention.backends.mla.concatv4_decode_only import concat_helper
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class FlashMLABackend(MLACommonBackend):
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "FLASHMLA_VLLM_V1"
|
||||
|
||||
@staticmethod
|
||||
def get_metadata_cls() -> type["FlashMLAMetadata"]:
|
||||
return FlashMLAMetadata
|
||||
|
||||
@staticmethod
|
||||
def get_builder_cls() -> type["FlashMLAMetadataBuilder"]:
|
||||
return FlashMLAMetadataBuilder
|
||||
|
||||
@staticmethod
|
||||
def get_impl_cls() -> type["FlashMLAImpl"]:
|
||||
return FlashMLAImpl
|
||||
|
||||
|
||||
@dataclass
|
||||
class FlashMLADecodeMetadata(MLACommonDecodeMetadata):
|
||||
tile_scheduler_metadata: torch.Tensor
|
||||
num_splits: torch.Tensor
|
||||
|
||||
|
||||
@dataclass
|
||||
class FlashMLAMetadata(MLACommonMetadata[FlashMLADecodeMetadata]):
|
||||
pass
|
||||
|
||||
|
||||
class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
|
||||
full_cudagraph_supported: ClassVar[bool] = True # Decode-only
|
||||
|
||||
def __init__(self, runner, kv_cache_spec: AttentionSpec,
|
||||
block_table: BlockTable):
|
||||
super().__init__(runner, kv_cache_spec, block_table, FlashMLAMetadata)
|
||||
|
||||
self.num_q_heads = self.runner.model_config.get_num_attention_heads(
|
||||
self.runner.parallel_config)
|
||||
|
||||
self.cg_buf_tile_scheduler_metadata = None
|
||||
self.cg_buf_num_splits = None
|
||||
|
||||
def _build_decode(self, block_table_tensor: torch.Tensor,
|
||||
seq_lens: torch.Tensor) -> FlashMLADecodeMetadata:
|
||||
tile_scheduler_metadata, num_splits = \
|
||||
get_mla_metadata(
|
||||
seq_lens,
|
||||
self.num_q_heads,
|
||||
1, # MQA for the decode path
|
||||
)
|
||||
|
||||
if self.runner.full_cuda_graph:
|
||||
# First time around (CUDAGraph capture), allocate the static buffer
|
||||
if self.cg_buf_tile_scheduler_metadata is None:
|
||||
self.cg_buf_tile_scheduler_metadata = tile_scheduler_metadata
|
||||
self.cg_buf_num_splits = num_splits
|
||||
else:
|
||||
assert self.cg_buf_num_splits is not None
|
||||
|
||||
# Metadata per-SM, fixed size (#SMs, TileMetadataSize)
|
||||
assert (self.cg_buf_tile_scheduler_metadata.size() ==
|
||||
tile_scheduler_metadata.size())
|
||||
self.cg_buf_tile_scheduler_metadata.\
|
||||
copy_(tile_scheduler_metadata)
|
||||
tile_scheduler_metadata = self.cg_buf_tile_scheduler_metadata
|
||||
|
||||
# Num splits is per-batch, varying size (batch_size,)
|
||||
n = num_splits.size(0)
|
||||
# make sure static buffer is large enough
|
||||
assert n <= self.cg_buf_num_splits.size(0)
|
||||
num_splits_view = self.cg_buf_num_splits[:n]
|
||||
num_splits_view.copy_(num_splits)
|
||||
self.cg_buf_num_splits[n:].fill_(0) # fill the rest with 0s
|
||||
num_splits = num_splits_view
|
||||
|
||||
return FlashMLADecodeMetadata(
|
||||
block_table=block_table_tensor,
|
||||
seq_lens=seq_lens,
|
||||
tile_scheduler_metadata=tile_scheduler_metadata,
|
||||
num_splits=num_splits,
|
||||
)
|
||||
|
||||
|
||||
class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
scale: float,
|
||||
num_kv_heads: int,
|
||||
alibi_slopes: Optional[list[float]],
|
||||
sliding_window: Optional[int],
|
||||
kv_cache_dtype: str,
|
||||
blocksparse_params: Optional[dict[str, Any]],
|
||||
logits_soft_cap: Optional[float],
|
||||
attn_type: str,
|
||||
kv_sharing_target_layer_name: Optional[str],
|
||||
# MLA Specific Arguments
|
||||
**mla_args) -> None:
|
||||
super().__init__(num_heads, head_size, scale, num_kv_heads,
|
||||
alibi_slopes, sliding_window, kv_cache_dtype,
|
||||
blocksparse_params, logits_soft_cap, attn_type,
|
||||
kv_sharing_target_layer_name, **mla_args)
|
||||
|
||||
assert is_flashmla_supported(), \
|
||||
"FlashMLA is not supported on this device"
|
||||
|
||||
unsupported_features = [
|
||||
alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap
|
||||
]
|
||||
if any(unsupported_features):
|
||||
raise NotImplementedError(
|
||||
"FlashMLAImpl does not support one of the following: "
|
||||
"alibi_slopes, sliding_window, blocksparse_params, "
|
||||
"logits_soft_cap")
|
||||
|
||||
if attn_type != AttentionType.DECODER:
|
||||
raise NotImplementedError("Encoder self-attention and "
|
||||
"encoder/decoder cross-attention "
|
||||
"are not implemented for "
|
||||
"FlashMLAImpl")
|
||||
|
||||
if is_quantized_kv_cache(self.kv_cache_dtype):
|
||||
if self.kv_cache_dtype != "fp8":
|
||||
raise NotImplementedError(
|
||||
"FlashMLA with other KV cache not yet supported")
|
||||
|
||||
def _forward_decode(
|
||||
self,
|
||||
q_nope: torch.Tensor,
|
||||
q_pe: torch.Tensor,
|
||||
kv_c_and_k_pe_cache: torch.Tensor,
|
||||
attn_metadata: FlashMLAMetadata,
|
||||
        k_scale=None,
|
||||
kv_cache_dtype = "auto",
|
||||
) -> torch.Tensor:
|
||||
assert kv_c_and_k_pe_cache.numel() > 0
|
||||
assert attn_metadata.decode is not None
|
||||
|
||||
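        # Use the custom Triton concat kernel (concatv4_decode_only) when the
        # VLLM_USE_TRITON_CAT env flag is set; the kernel is only exercised
        # for decode batches of at most 1024 tokens, while larger batches
        # fall back to torch.cat below.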
if envs.VLLM_USE_TRITON_CAT:
|
||||
if q_nope.shape[0] <= 1024:
|
||||
q = concat_helper(q_nope, q_pe, dim=-1)\
|
||||
.unsqueeze(1)
|
||||
else:
|
||||
q = torch.cat([q_nope, q_pe], dim=-1)\
|
||||
.unsqueeze(1) # Add seqlen dim of 1 (decode)
|
||||
else:
|
||||
q = torch.cat([q_nope, q_pe], dim=-1)\
|
||||
.unsqueeze(1) # Add seqlen dim of 1 (decode)
|
||||
|
||||
o, _ = flash_mla_with_kvcache(
|
||||
q=q,
|
||||
k_cache=kv_c_and_k_pe_cache.unsqueeze(-2), # Add head dim of 1
|
||||
block_table=attn_metadata.decode.block_table,
|
||||
cache_seqlens=attn_metadata.decode.seq_lens,
|
||||
head_dim_v=self.kv_lora_rank,
|
||||
tile_scheduler_metadata=attn_metadata.decode.
|
||||
tile_scheduler_metadata,
|
||||
num_splits=attn_metadata.decode.num_splits,
|
||||
softmax_scale=self.scale,
|
||||
causal=True,
|
||||
            k_scale=k_scale,
|
||||
            kv_cache_dtype=kv_cache_dtype,
|
||||
)
|
||||
|
||||
return self._v_up_proj(o)
|
||||
241
vllm/v1/attention/backends/mla/rocm_aiter_mla.py
Normal file
241
vllm/v1/attention/backends/mla/rocm_aiter_mla.py
Normal file
@@ -0,0 +1,241 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, ClassVar, Optional
|
||||
|
||||
import torch
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.attention.ops.rocm_aiter_mla import aiter_mla_decode_fwd
|
||||
# yapf conflicts with isort for this docstring
|
||||
# yapf: disable
|
||||
from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
|
||||
MLACommonDecodeMetadata,
|
||||
MLACommonImpl,
|
||||
MLACommonMetadata,
|
||||
MLACommonMetadataBuilder)
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||
from vllm.v1.worker.block_table import BlockTable
|
||||
|
||||
# yapf: enable
|
||||
|
||||
|
||||
def is_aiter_mla_enabled() -> bool:
|
||||
return envs.VLLM_ROCM_USE_AITER \
|
||||
and envs.VLLM_ROCM_USE_AITER_MLA
|
||||
|
||||
|
||||
class AiterMLABackend(MLACommonBackend):
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "ROCM_AITER_MLA_VLLM_V1"
|
||||
|
||||
@staticmethod
|
||||
def get_impl_cls() -> type["AiterMLAImpl"]:
|
||||
return AiterMLAImpl
|
||||
|
||||
@staticmethod
|
||||
def get_metadata_cls() -> type["AiterMLAMetadata"]:
|
||||
return AiterMLAMetadata
|
||||
|
||||
@staticmethod
|
||||
def get_builder_cls() -> type["AiterMLAMetadataBuilder"]:
|
||||
return AiterMLAMetadataBuilder
|
||||
|
||||
|
||||
@dataclass
|
||||
class AiterMLADecodeMetadata(MLACommonDecodeMetadata):
|
||||
# The indptr of the paged kv cache, shape: [batch_size + 1]
|
||||
paged_kv_indptr: Optional[torch.Tensor] = None
|
||||
# The page indices of the paged kv cache
|
||||
paged_kv_indices: Optional[torch.Tensor] = None
|
||||
# The number of entries in the last page of each request in
|
||||
# the paged kv cache, shape: [batch_size]
|
||||
paged_kv_last_page_len: Optional[torch.Tensor] = None
|
||||
# The query indptr, shape : [num_decode + 1]
|
||||
qo_indptr: Optional[torch.Tensor] = None
|
||||
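    # Illustrative example (block_size is asserted to be 1 below): for two
    # decode requests with seq_lens [3, 5],
    #   paged_kv_indptr        = [0, 3, 8]
    #   paged_kv_indices       = the 8 block ids taken row-wise from the table
    #   paged_kv_last_page_len = [1, 1]    (page size 1, so always 1)
    #   qo_indptr              = [0, 1, 2] (one query token per decode request)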
|
||||
|
||||
class AiterMLAMetadata(MLACommonMetadata[AiterMLADecodeMetadata]):
|
||||
pass
|
||||
|
||||
|
||||
class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
|
||||
full_cudagraph_supported: ClassVar[bool] = True # decode only
|
||||
|
||||
def __init__(self, runner, kv_cache_spec: AttentionSpec,
|
||||
block_table: BlockTable):
|
||||
super().__init__(runner, kv_cache_spec, block_table, AiterMLAMetadata)
|
||||
        assert self.kv_cache_spec.block_size == 1, "AITER MLA " \
|
||||
"only supports block size 1."
|
||||
|
||||
# Preparing persistent buffers
|
||||
if self.runner.full_cuda_graph:
|
||||
device = self.runner.device
|
||||
max_num_reqs = self.runner.max_num_reqs
|
||||
self.paged_kv_indptr = torch.zeros(max_num_reqs + 1,
|
||||
dtype=torch.int32,
|
||||
device=device)
|
||||
self.paged_kv_indices = torch.zeros(
|
||||
block_table.get_device_tensor().numel(
|
||||
), # max num pages possible
|
||||
dtype=torch.int32,
|
||||
device=device)
|
||||
self.paged_kv_last_page_len = torch.zeros(max_num_reqs,
|
||||
dtype=torch.int32,
|
||||
device=device)
|
||||
|
||||
self.qo_indptr = torch.arange(0,
|
||||
max_num_reqs + 1,
|
||||
dtype=torch.int32,
|
||||
device=device)
|
||||
|
||||
def _build_decode(self, block_table_tensor: torch.Tensor,
|
||||
seq_lens: torch.Tensor) -> AiterMLADecodeMetadata:
|
||||
page_size = self.kv_cache_spec.block_size
|
||||
block_table_bounds = (seq_lens + page_size - 1) // page_size
|
||||
device = self.runner.device
|
||||
|
||||
mask = (torch.arange(block_table_tensor.size(1),
|
||||
dtype=block_table_tensor.dtype,
|
||||
device=device).unsqueeze(0)
|
||||
< block_table_bounds.unsqueeze(1))
|
||||
paged_kv_indices = block_table_tensor[mask]
|
||||
|
||||
paged_kv_last_page_len = seq_lens % page_size
|
||||
paged_kv_last_page_len = torch.where(paged_kv_last_page_len == 0,
|
||||
page_size, paged_kv_last_page_len)
|
||||
|
||||
paged_kv_indptr = torch.cat([
|
||||
torch.zeros(1, dtype=block_table_bounds.dtype, device=device),
|
||||
block_table_bounds.cumsum(dim=0, dtype=torch.int32)
|
||||
])
|
||||
|
||||
if self.runner.full_cuda_graph:
|
||||
num_reqs = self._num_decodes
|
||||
|
||||
num_actual_pages = paged_kv_indices.size(0)
|
||||
|
||||
self.paged_kv_indices[:num_actual_pages].copy_(paged_kv_indices,
|
||||
non_blocking=True)
|
||||
self.paged_kv_indices[num_actual_pages:].fill_(-1)
|
||||
paged_kv_indices = self.paged_kv_indices[:num_actual_pages]
|
||||
|
||||
self.paged_kv_indptr[:1 + num_reqs].copy_(paged_kv_indptr,
|
||||
non_blocking=True)
|
||||
self.paged_kv_indptr[1 + num_reqs:].fill_(paged_kv_indptr[-1])
|
||||
paged_kv_indptr = self.paged_kv_indptr[:1 + num_reqs]
|
||||
|
||||
self.paged_kv_last_page_len[:num_reqs].copy_(
|
||||
paged_kv_last_page_len, non_blocking=True)
|
||||
self.paged_kv_last_page_len[num_reqs:].fill_(1)
|
||||
paged_kv_last_page_len = self.paged_kv_last_page_len[:num_reqs]
|
||||
|
||||
qo_indptr = self.qo_indptr[:1 + num_reqs]
|
||||
|
||||
else:
|
||||
qo_indptr = torch.arange(0,
|
||||
self._num_decodes + 1,
|
||||
step=1,
|
||||
dtype=torch.int32,
|
||||
device=device)
|
||||
|
||||
attn_metadata = AiterMLADecodeMetadata(
|
||||
block_table=block_table_tensor,
|
||||
seq_lens=seq_lens,
|
||||
paged_kv_indptr=paged_kv_indptr,
|
||||
paged_kv_indices=paged_kv_indices,
|
||||
paged_kv_last_page_len=paged_kv_last_page_len,
|
||||
qo_indptr=qo_indptr)
|
||||
|
||||
return attn_metadata
|
||||
|
||||
|
||||
class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
scale: float,
|
||||
num_kv_heads: int,
|
||||
alibi_slopes: Optional[list[float]],
|
||||
sliding_window: Optional[int],
|
||||
kv_cache_dtype: str,
|
||||
blocksparse_params: Optional[dict[str, Any]],
|
||||
logits_soft_cap: Optional[float],
|
||||
attn_type: str,
|
||||
kv_sharing_target_layer_name: Optional[str],
|
||||
# MLA Specific Arguments
|
||||
**mla_args) -> None:
|
||||
super().__init__(num_heads, head_size, scale, num_kv_heads,
|
||||
alibi_slopes, sliding_window, kv_cache_dtype,
|
||||
blocksparse_params, logits_soft_cap, attn_type,
|
||||
kv_sharing_target_layer_name, **mla_args)
|
||||
assert (num_heads == 16 or num_heads == 128), (
|
||||
f"Aiter MLA only supports 16 or 128 number of heads.\n"
|
||||
f"Provided {num_heads} number of heads.\n"
|
||||
"Try adjusting tensor_parallel_size value.")
|
||||
unsupported_features = [
|
||||
alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap
|
||||
]
|
||||
if any(unsupported_features):
|
||||
raise NotImplementedError(
|
||||
"Aiter MLA does not support one of the following: "
|
||||
"alibi_slopes, sliding_window, blocksparse_params, "
|
||||
"logits_soft_cap")
|
||||
|
||||
from aiter import flash_attn_varlen_func
|
||||
self.flash_attn_varlen_func = flash_attn_varlen_func
|
||||
|
||||
def _flash_attn_varlen_diff_headdims(self,
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
return_softmax_lse=False,
|
||||
softmax_scale=None,
|
||||
**kwargs):
|
||||
output = self.flash_attn_varlen_func(
|
||||
q=q,
|
||||
k=k,
|
||||
v=v,
|
||||
softmax_scale=softmax_scale,
|
||||
return_lse=return_softmax_lse,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
return output
|
||||
|
||||
def _forward_decode(
|
||||
self,
|
||||
q_nope: torch.Tensor,
|
||||
q_pe: torch.Tensor,
|
||||
kv_c_and_k_pe_cache: torch.Tensor,
|
||||
attn_metadata: AiterMLAMetadata,
|
||||
) -> torch.Tensor:
|
||||
assert kv_c_and_k_pe_cache.numel() > 0
|
||||
assert attn_metadata.decode is not None
|
||||
|
||||
B = q_nope.shape[0]
|
||||
|
||||
q = torch.cat([q_nope, q_pe], dim=-1)
|
||||
o = torch.zeros(B,
|
||||
self.num_heads,
|
||||
self.kv_lora_rank,
|
||||
dtype=q.dtype,
|
||||
device=q.device)
|
||||
|
||||
kv_buffer = kv_c_and_k_pe_cache.unsqueeze(2)
|
||||
|
||||
# max_seqlen_qo must be 1 except for MTP
|
||||
# TODO: Find the best value for MTP
|
||||
max_seqlen_qo = 1
|
||||
aiter_mla_decode_fwd(q, kv_buffer, o, self.scale,
|
||||
attn_metadata.decode.qo_indptr, max_seqlen_qo,
|
||||
attn_metadata.decode.paged_kv_indptr,
|
||||
attn_metadata.decode.paged_kv_indices,
|
||||
attn_metadata.decode.paged_kv_last_page_len)
|
||||
|
||||
return self._v_up_proj(o)
|
||||
177
vllm/v1/attention/backends/mla/triton_mla.py
Normal file
177
vllm/v1/attention/backends/mla/triton_mla.py
Normal file
@@ -0,0 +1,177 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm import envs
|
||||
from vllm.attention.backends.abstract import (AttentionType,
|
||||
is_quantized_kv_cache)
|
||||
from vllm.attention.ops.triton_decode_attention import decode_attention_fwd
|
||||
from vllm.attention.ops.triton_flash_attention import triton_attention
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.triton_utils import HAS_TRITON
|
||||
from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
|
||||
MLACommonImpl,
|
||||
MLACommonMetadata)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class TritonMLABackend(MLACommonBackend):
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "TRITON_MLA_VLLM_V1"
|
||||
|
||||
@staticmethod
|
||||
def get_impl_cls() -> type["TritonMLAImpl"]:
|
||||
return TritonMLAImpl
|
||||
|
||||
|
||||
class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
scale: float,
|
||||
num_kv_heads: int,
|
||||
alibi_slopes: Optional[list[float]],
|
||||
sliding_window: Optional[int],
|
||||
kv_cache_dtype: str,
|
||||
blocksparse_params: Optional[dict[str, Any]],
|
||||
logits_soft_cap: Optional[float],
|
||||
attn_type: str,
|
||||
kv_sharing_target_layer_name: Optional[str],
|
||||
# MLA Specific Arguments
|
||||
**mla_args) -> None:
|
||||
super().__init__(num_heads, head_size, scale, num_kv_heads,
|
||||
alibi_slopes, sliding_window, kv_cache_dtype,
|
||||
blocksparse_params, logits_soft_cap, attn_type,
|
||||
kv_sharing_target_layer_name, **mla_args)
|
||||
|
||||
unsupported_features = [
|
||||
alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap
|
||||
]
|
||||
if any(unsupported_features):
|
||||
raise NotImplementedError(
|
||||
"TritonMLAImpl does not support one of the following: "
|
||||
"alibi_slopes, sliding_window, blocksparse_params, "
|
||||
"logits_soft_cap")
|
||||
|
||||
if attn_type != AttentionType.DECODER:
|
||||
raise NotImplementedError("Encoder self-attention and "
|
||||
"encoder/decoder cross-attention "
|
||||
"are not implemented for "
|
||||
"TritonMLAImpl")
|
||||
|
||||
if is_quantized_kv_cache(self.kv_cache_dtype):
|
||||
raise NotImplementedError(
|
||||
"TritonMLA V1 with FP8 KV cache not yet supported")
|
||||
|
||||
self.use_triton_flash_attn = envs.VLLM_USE_TRITON_FLASH_ATTN
|
||||
self.triton_fa_func = triton_attention if HAS_TRITON else None
|
||||
|
||||
def _flash_attn_varlen_diff_headdims_rocm(self,
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
softmax_scale=None,
|
||||
**kwargs):
|
||||
assert self.triton_fa_func is not None
|
||||
|
||||
# Triton Attention requires a padded V
|
||||
padded_v = torch.nn.functional.pad(v, [0, q.shape[-1] - v.shape[-1]],
|
||||
value=0)
|
||||
# The output of triton_attention is a tuple of
|
||||
# [output_tensor, encoded_softmax] where encoded_softmax is always None
|
||||
output_tensor, _ = self.triton_fa_func(
|
||||
q,
|
||||
k,
|
||||
padded_v,
|
||||
None, # output
|
||||
kwargs["cu_seqlens_q"],
|
||||
kwargs["cu_seqlens_k"],
|
||||
kwargs["max_seqlen_q"],
|
||||
kwargs["max_seqlen_k"],
|
||||
kwargs["causal"],
|
||||
softmax_scale,
|
||||
None, # bias
|
||||
)
|
||||
|
||||
return output_tensor
|
||||
|
||||
def _flash_attn_varlen_diff_headdims(self,
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
return_softmax_lse=False,
|
||||
softmax_scale=None,
|
||||
**kwargs):
|
||||
if current_platform.is_rocm() \
|
||||
and self.use_triton_flash_attn \
|
||||
and not return_softmax_lse:
|
||||
return self._flash_attn_varlen_diff_headdims_rocm(
|
||||
q, k, v, softmax_scale=softmax_scale, **kwargs)
|
||||
else:
|
||||
return super()._flash_attn_varlen_diff_headdims(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
return_softmax_lse=return_softmax_lse,
|
||||
softmax_scale=softmax_scale,
|
||||
**kwargs)
|
||||
|
||||
def _forward_decode(
|
||||
self,
|
||||
q_nope: torch.Tensor,
|
||||
q_pe: torch.Tensor,
|
||||
kv_c_and_k_pe_cache: torch.Tensor,
|
||||
attn_metadata: MLACommonMetadata,
|
||||
) -> torch.Tensor:
|
||||
assert kv_c_and_k_pe_cache.numel() > 0
|
||||
assert attn_metadata.decode is not None
|
||||
|
||||
if self.kv_cache_dtype.startswith("fp8"):
|
||||
raise NotImplementedError("FP8 Triton MLA not yet supported")
|
||||
|
||||
B = q_nope.shape[0]
|
||||
|
||||
q = torch.cat([q_nope, q_pe], dim=-1)
|
||||
o = torch.zeros(B,
|
||||
self.num_heads,
|
||||
self.kv_lora_rank,
|
||||
dtype=q.dtype,
|
||||
device=q.device)
|
||||
|
||||
num_kv_splits = 4 # TODO: heuristic
|
||||
|
||||
# TODO(lucas) Allocate ahead of time
|
||||
attn_logits = torch.empty(
|
||||
(
|
||||
B,
|
||||
self.num_heads,
|
||||
num_kv_splits,
|
||||
                # NOTE(lucas): it is unclear why the +1 is needed, but sglang has it, so we
|
||||
# just mirror that
|
||||
self.kv_lora_rank + 1,
|
||||
),
|
||||
dtype=torch.float32,
|
||||
device=q.device,
|
||||
)
|
||||
|
||||
# Add a head dim of 1
|
||||
kv_c_and_k_pe_cache = kv_c_and_k_pe_cache.unsqueeze(2)
|
||||
kv_c_cache = kv_c_and_k_pe_cache[..., :self.kv_lora_rank]
|
||||
PAGE_SIZE = kv_c_and_k_pe_cache.size(1)
|
||||
|
||||
# Run MQA
|
||||
decode_attention_fwd(q, kv_c_and_k_pe_cache, kv_c_cache, o,
|
||||
attn_metadata.decode.block_table,
|
||||
attn_metadata.decode.seq_lens, attn_logits,
|
||||
num_kv_splits, self.scale, PAGE_SIZE)
|
||||
|
||||
return self._v_up_proj(o)
|
||||
320
vllm/v1/attention/backends/pallas.py
Normal file
320
vllm/v1/attention/backends/pallas.py
Normal file
@@ -0,0 +1,320 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
import torch_xla.core.xla_builder as xb
|
||||
import torch_xla.experimental.custom_kernel # noqa: F401
|
||||
# Required to register custom ops.
|
||||
from torch.library import impl
|
||||
from torch_xla._internal.jax_workarounds import requires_jax
|
||||
from torch_xla.experimental.custom_kernel import XLA_LIB
|
||||
|
||||
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
||||
AttentionLayer, AttentionType)
|
||||
from vllm.attention.backends.utils import CommonAttentionState
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import cdiv, next_power_of_2
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
# TPU requires the head size to be a multiple of 128.
|
||||
TPU_HEAD_SIZE_ALIGNMENT = 128
|
||||
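# e.g. a model head size of 96 gets padded to cdiv(96, 128) * 128 = 128 in
# forward() and in get_kv_cache_shape() before reaching the Pallas kernel.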
|
||||
|
||||
class PallasAttentionBackend(AttentionBackend):
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "PALLAS_VLLM_V1"
|
||||
|
||||
@staticmethod
|
||||
def get_impl_cls() -> type["PallasAttentionBackendImpl"]:
|
||||
return PallasAttentionBackendImpl
|
||||
|
||||
@staticmethod
|
||||
def get_metadata_cls() -> type["PallasMetadata"]:
|
||||
return PallasMetadata
|
||||
|
||||
@staticmethod
|
||||
def get_state_cls() -> type["CommonAttentionState"]:
|
||||
return CommonAttentionState
|
||||
|
||||
@staticmethod
|
||||
def get_kv_cache_shape(
|
||||
num_blocks: int,
|
||||
block_size: int,
|
||||
num_kv_heads: int,
|
||||
head_size: int,
|
||||
) -> tuple[int, ...]:
|
||||
padded_head_size = cdiv(
|
||||
head_size, TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT
|
||||
return (num_blocks, block_size, num_kv_heads * 2, padded_head_size)
|
||||
|
||||
@staticmethod
|
||||
def swap_blocks(
|
||||
src_kv_cache: torch.Tensor,
|
||||
dst_kv_cache: torch.Tensor,
|
||||
src_to_dst: torch.Tensor,
|
||||
) -> None:
|
||||
raise RuntimeError("swap_blocks is not used for the TPU backend.")
|
||||
|
||||
# In recent TPU generations, up to v6e, the SMEM size is 1MB. The
|
||||
# block_tables within the PallasMetadata constitute almost the entire SMEM
|
||||
# requirement. Its size is max_num_seqs * num_page_per_seq * 4 (Int). Here
|
||||
# we simply make sure that the size is smaller than half of SMEM capacity.
|
||||
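    # Worked example with assumed numbers (not from any config): with
    # max_num_seqs = 256 the budget is 1MB / 2 / 256 / 4 = 512 pages per
    # request, so a 131072-token max_model_len needs pages of at least
    # cdiv(131072, 512) = 256 tokens, which is already a power of two.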
@staticmethod
|
||||
def get_min_page_size(vllm_config: VllmConfig) -> int:
|
||||
max_num_page_per_req = (1024 * 1024 // 2 //
|
||||
vllm_config.scheduler_config.max_num_seqs // 4)
|
||||
min_page_size = cdiv(vllm_config.model_config.max_model_len,
|
||||
max_num_page_per_req)
|
||||
min_page_size = 1 << (min_page_size - 1).bit_length()
|
||||
return min_page_size
|
||||
|
||||
@staticmethod
|
||||
def get_max_num_seqs(model_len: int, page_size: int) -> int:
|
||||
num_page_per_req = cdiv(model_len, page_size)
|
||||
return 1024 * 1024 // 2 // num_page_per_req // 4
|
||||
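    # e.g. (illustrative) model_len = 8192 with page_size = 64 gives
    # cdiv(8192, 64) = 128 pages per request and a cap of
    # 1MB / 2 / 128 / 4 = 1024 sequences.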
|
||||
# TPU has limited SREGs (scalar registers), if page_size is too small, we
|
||||
# can spill SREGs easily which leads to bad performance. The strategy we
|
||||
# apply here is trying to split max-model-len to 16 pages which make the
|
||||
# spill less likely. Meanwhile we make sure the page size is in [16, 256].
|
||||
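    # For example (illustrative): max_model_len = 2048 gives
    # next_power_of_2(2048) // 16 = 128, which lies in [16, 256]; an
    # 8192-token model would give 512 and be clamped down to 256.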
@staticmethod
|
||||
def get_page_size(vllm_config: VllmConfig) -> int:
|
||||
page_size = next_power_of_2(
|
||||
vllm_config.model_config.max_model_len) // 16
|
||||
if page_size <= 16:
|
||||
return 16
|
||||
if page_size >= 256:
|
||||
return 256
|
||||
return page_size
|
||||
|
||||
|
||||
@dataclass
|
||||
class PallasMetadata:
|
||||
# NOTE(sang): Definition of context_len, query_len, and seq_len.
|
||||
# |---------- N-1 iteration --------|
|
||||
# |---------------- N iteration ---------------------|
|
||||
# |- tokenA -|......................|-- newTokens ---|
|
||||
# |---------- context_len ----------|
|
||||
# |-------------------- seq_len ---------------------|
|
||||
# |-- query_len ---|
|
||||
|
||||
# Used in the PallasAttentionBackendImpl
|
||||
slot_mapping: torch.Tensor
|
||||
block_tables: torch.Tensor
|
||||
context_lens: torch.Tensor
|
||||
query_start_loc: torch.Tensor
|
||||
num_seqs: torch.Tensor
|
||||
num_kv_update_slices: torch.Tensor
|
||||
num_slices_per_kv_cache_update_block: int
|
||||
|
||||
|
||||
class PallasAttentionBackendImpl(AttentionImpl):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
scale: float,
|
||||
num_kv_heads: int,
|
||||
alibi_slopes: Optional[list[float]],
|
||||
sliding_window: Optional[int],
|
||||
kv_cache_dtype: str,
|
||||
blocksparse_params: Optional[dict[str, Any]] = None,
|
||||
logits_soft_cap: Optional[float] = None,
|
||||
attn_type: str = AttentionType.DECODER,
|
||||
kv_sharing_target_layer_name: Optional[int] = None,
|
||||
use_irope: bool = False,
|
||||
) -> None:
|
||||
if use_irope:
|
||||
logger.warning_once(
|
||||
"Using irope in Pallas is not supported yet, it will fall back "
|
||||
"to global attention for long context.")
|
||||
if blocksparse_params is not None:
|
||||
raise ValueError("Paged attention Pallas kernel does "
|
||||
"not support block-sparse attention.")
|
||||
self.num_heads = num_heads
|
||||
self.head_size = head_size
|
||||
self.scale = float(scale)
|
||||
self.num_kv_heads = num_kv_heads
|
||||
self.sliding_window = sliding_window
|
||||
self.logits_soft_cap = logits_soft_cap
|
||||
self.kv_sharing_target_layer_name = kv_sharing_target_layer_name
|
||||
|
||||
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
||||
if alibi_slopes is not None:
|
||||
raise NotImplementedError("Alibi slopes is not supported.")
|
||||
if kv_cache_dtype != "auto":
|
||||
raise NotImplementedError("FP8 KV cache dtype is not supported.")
|
||||
if blocksparse_params is not None:
|
||||
raise NotImplementedError("Blocksparse is not supported.")
|
||||
|
||||
if attn_type != AttentionType.DECODER:
|
||||
raise NotImplementedError("Encoder self-attention and "
|
||||
"encoder/decoder cross-attention "
|
||||
"are not implemented for "
|
||||
"PallasAttentionBackendImpl")
|
||||
|
||||
tpu_version = torch_xla.tpu.version()
|
||||
if tpu_version < 4:
|
||||
raise NotImplementedError("TPU version must be 4 or higher.")
|
||||
|
||||
def forward(
|
||||
self,
|
||||
layer: AttentionLayer,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: PallasMetadata,
|
||||
output: Optional[torch.Tensor] = None,
|
||||
output_scale: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
"""Forward pass with Pallas attention.
|
||||
|
||||
Args:
|
||||
query: shape = [num_tokens, num_heads * head_size]
|
||||
key: shape = [num_tokens, num_kv_heads * head_size]
|
||||
value: shape = [num_tokens, num_kv_heads * head_size]
|
||||
kv_cache = [num_blocks, block_size, num_kv_heads * 2, head_size]
|
||||
attn_metadata: Metadata for attention.
|
||||
Returns:
|
||||
shape = [num_tokens, num_heads * head_size]
|
||||
"""
|
||||
if output_scale is not None:
|
||||
raise NotImplementedError(
|
||||
"fused output quantization is not yet supported"
|
||||
" for PallasAttentionBackendImpl")
|
||||
|
||||
# For determine_available_memory case.
|
||||
if kv_cache.numel() == 0:
|
||||
if output is None:
|
||||
output = torch.ones_like(query)
|
||||
return output
|
||||
|
||||
assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0
|
||||
num_tokens, hidden_size = query.shape
|
||||
query = query.view(num_tokens, self.num_heads, self.head_size)
|
||||
key = key.view(-1, self.num_kv_heads, self.head_size)
|
||||
value = value.view(-1, self.num_kv_heads, self.head_size)
|
||||
if self.head_size % TPU_HEAD_SIZE_ALIGNMENT != 0:
|
||||
padded_head_size = cdiv(
|
||||
self.head_size,
|
||||
TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT
|
||||
query = torch.nn.functional.pad(
|
||||
query, (0, padded_head_size - self.head_size), value=0.0)
|
||||
key = torch.nn.functional.pad(
|
||||
key, (0, padded_head_size - self.head_size), value=0.0)
|
||||
value = torch.nn.functional.pad(
|
||||
value, (0, padded_head_size - self.head_size), value=0.0)
|
||||
|
||||
if self.kv_sharing_target_layer_name is None and kv_cache.numel() > 0:
|
||||
# Write input keys and values to the KV cache.
|
||||
# Skip this if sharing KV cache with an earlier attention layer.
|
||||
slot_mapping = attn_metadata.slot_mapping
|
||||
write_to_kv_cache(
|
||||
key, value, kv_cache, slot_mapping,
|
||||
attn_metadata.num_slices_per_kv_cache_update_block,
|
||||
attn_metadata.num_kv_update_slices)
|
||||
|
||||
output = torch.ops.xla.ragged_paged_attention(
|
||||
query,
|
||||
kv_cache,
|
||||
attn_metadata.context_lens,
|
||||
attn_metadata.block_tables,
|
||||
attn_metadata.query_start_loc,
|
||||
attn_metadata.num_seqs,
|
||||
# By default, the system utilizes optimized block size and
|
||||
# vmem_limit_bytes parameters from the kernel repository. However,
|
||||
# these can be manually adjusted for debugging if necessary.
|
||||
num_kv_pages_per_block=None,
|
||||
num_queries_per_block=None,
|
||||
vmem_limit_bytes=None,
|
||||
use_kernel=True,
|
||||
sm_scale=self.scale,
|
||||
sliding_window=self.sliding_window,
|
||||
soft_cap=self.logits_soft_cap,
|
||||
)
|
||||
|
||||
if self.head_size % TPU_HEAD_SIZE_ALIGNMENT != 0:
|
||||
output = output[:, :, :self.head_size]
|
||||
|
||||
return output.reshape(num_tokens, hidden_size)
|
||||
|
||||
|
||||
def write_to_kv_cache(
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
slot_mapping: torch.Tensor,
|
||||
num_slices_per_kv_cache_update_block: int,
|
||||
num_kv_update_slices: torch.Tensor,
|
||||
) -> None:
|
||||
""" Write the key and values to the KV cache.
|
||||
|
||||
Args:
|
||||
key: shape = [num_tokens, num_kv_heads * head_size]
|
||||
value: shape = [num_tokens, num_kv_heads * head_size]
|
||||
kv_cache = [num_blocks, block_size, num_kv_heads * 2, head_size]
|
||||
num_slices_per_kv_cache_update_block: int
|
||||
"""
|
||||
_, page_size, num_combined_kv_heads, head_size = kv_cache.shape
|
||||
head_size = cdiv(head_size,
|
||||
TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT
|
||||
kv = torch.cat([key, value], axis=-1).reshape(-1, num_combined_kv_heads,
|
||||
head_size)
|
||||
|
||||
torch.ops.xla.dynamo_set_buffer_donor_(kv_cache, True)
|
||||
|
||||
kv_cache = kv_cache.flatten(0, 1)
|
||||
new_kv_cache = torch.ops.xla.kv_cache_update_op(
|
||||
kv, slot_mapping, kv_cache, num_kv_update_slices, page_size,
|
||||
num_slices_per_kv_cache_update_block)
|
||||
# NOTE: the in-place copy will be optimized away by XLA compiler.
|
||||
kv_cache.copy_(new_kv_cache)
|
||||
|
||||
|
||||
@requires_jax
|
||||
def kv_cache_update_op_impl(kv: torch.Tensor, slot_mapping: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
num_kv_update_slices: torch.Tensor, page_size: int,
|
||||
num_slices_per_block: int):
|
||||
from vllm.attention.ops.pallas_kv_cache_update import kv_cache_update
|
||||
new_kv_cache = xb.call_jax(
|
||||
kv_cache_update, (kv, slot_mapping, kv_cache, num_kv_update_slices), {
|
||||
"page_size": page_size,
|
||||
"num_slices_per_block": num_slices_per_block
|
||||
})
|
||||
return new_kv_cache
|
||||
|
||||
|
||||
XLA_LIB.define(
|
||||
"kv_cache_update_op(Tensor kv, Tensor slot_mapping, Tensor kv_cache," \
|
||||
"Tensor num_kv_update_slices, int page_size, int num_slices_per_block)" \
|
||||
"-> Tensor", )
|
||||
|
||||
|
||||
@impl(XLA_LIB, "kv_cache_update_op", "XLA")
|
||||
def kv_cache_update_op_xla(kv: torch.Tensor, slot_mapping: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
num_kv_update_slices: torch.Tensor, page_size: int,
|
||||
num_slices_per_block: int) -> torch.Tensor:
|
||||
new_kv_cache = kv_cache_update_op_impl(kv, slot_mapping, kv_cache,
|
||||
num_kv_update_slices, page_size,
|
||||
num_slices_per_block)
|
||||
return new_kv_cache
|
||||
|
||||
|
||||
@impl(XLA_LIB, "kv_cache_update_op", "CompositeExplicitAutograd")
|
||||
def kv_cache_update_op_non_xla(kv: torch.Tensor, slot_mapping: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
num_kv_update_slices: torch.Tensor,
|
||||
page_size: int,
|
||||
num_slices_per_block: int) -> torch.Tensor:
|
||||
return kv_cache
|
||||
609
vllm/v1/attention/backends/rocm_aiter_fa.py
Normal file
609
vllm/v1/attention/backends/rocm_aiter_fa.py
Normal file
@@ -0,0 +1,609 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Attention layer with AiterFlashAttention."""
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
||||
AttentionMetadata, AttentionType,
|
||||
is_quantized_kv_cache)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.v1.attention.backends.flash_attn import (
|
||||
make_local_attention_virtual_batches)
|
||||
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||
from vllm.v1.worker.block_table import BlockTable
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.worker.gpu_input_batch import InputBatch
|
||||
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
|
||||
|
||||
if current_platform.is_rocm():
|
||||
import aiter
|
||||
|
||||
from vllm.triton_utils import tl, triton
|
||||
from vllm.utils import direct_register_custom_op
|
||||
|
||||
@triton.jit
|
||||
def _vllm_layout_trans_kernel(
|
||||
k_buffer_ptr,
|
||||
v_buffer_ptr,
|
||||
k_values_ptr,
|
||||
v_values_ptr,
|
||||
b_query_lens_loc,
|
||||
b_seq_lens_loc,
|
||||
block_table,
|
||||
block_table_stride_0,
|
||||
E_DIM: tl.constexpr,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
):
|
||||
batch_idx = tl.program_id(0)
|
||||
block_idx = tl.program_id(1)
|
||||
batch_token_indexes = tl.load(b_seq_lens_loc + batch_idx +
|
||||
tl.arange(0, 2))
|
||||
batch_token_start, batch_token_end = tl.split(batch_token_indexes)
|
||||
seq_len = batch_token_end - batch_token_start
|
||||
|
||||
batch_query_indexes = tl.load(b_query_lens_loc + batch_idx +
|
||||
tl.arange(0, 2))
|
||||
batch_query_start, batch_query_end = tl.split(batch_query_indexes)
|
||||
query_len = batch_query_end - batch_query_start
|
||||
if query_len <= 1:
|
||||
return
|
||||
if block_idx * BLOCK_SIZE < seq_len:
|
||||
block_mask = (block_idx * BLOCK_SIZE +
|
||||
tl.arange(0, BLOCK_SIZE)[:, None]) < seq_len
|
||||
|
||||
kv_idx = tl.load(block_table + batch_idx * block_table_stride_0 +
|
||||
block_idx)
|
||||
|
||||
kv_buffer_off = kv_idx * BLOCK_SIZE * E_DIM + tl.arange(
|
||||
0, BLOCK_SIZE)[:, None] * E_DIM + tl.arange(0, E_DIM)[None, :]
|
||||
k_vals = tl.load(k_buffer_ptr + kv_buffer_off,
|
||||
mask=block_mask,
|
||||
other=0.0)
|
||||
v_vals = tl.load(v_buffer_ptr + kv_buffer_off,
|
||||
mask=block_mask,
|
||||
other=0.0)
|
||||
|
||||
kv_values_off = batch_token_start * E_DIM + \
|
||||
block_idx * BLOCK_SIZE * E_DIM + \
|
||||
tl.arange(0, BLOCK_SIZE)[:, None] * E_DIM + \
|
||||
tl.arange(0, E_DIM)[None, :]
|
||||
tl.store(k_values_ptr + kv_values_off, k_vals, mask=block_mask)
|
||||
tl.store(v_values_ptr + kv_values_off, v_vals, mask=block_mask)
|
||||
|
||||
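    # Gather the paged K/V cache blocks referenced by `block_table` into flat
    # [total_tokens, H_KV, D] tensors so that aiter.flash_attn_varlen_func can
    # consume them directly; decode-only requests (query_len <= 1) are skipped
    # in the kernel above.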
def vllm_layout_trans(b_query_lens_loc, b_seq_lens_loc, block_table,
|
||||
k_buffer, v_buffer, max_seq_len, total_tokens):
|
||||
H_KV = v_buffer.shape[2]
|
||||
D = v_buffer.shape[3]
|
||||
BLOCK_SIZE = v_buffer.shape[1]
|
||||
dtype = k_buffer.dtype
|
||||
k_values = torch.empty((total_tokens, H_KV, D),
|
||||
dtype=dtype,
|
||||
device="cuda")
|
||||
v_values = torch.empty((total_tokens, H_KV, D),
|
||||
dtype=dtype,
|
||||
device="cuda")
|
||||
|
||||
grid = (block_table.shape[0],
|
||||
(max_seq_len + BLOCK_SIZE - 1) // BLOCK_SIZE)
|
||||
|
||||
_vllm_layout_trans_kernel[grid](k_buffer,
|
||||
v_buffer,
|
||||
k_values,
|
||||
v_values,
|
||||
b_query_lens_loc,
|
||||
b_seq_lens_loc,
|
||||
block_table,
|
||||
block_table.stride(0),
|
||||
E_DIM=H_KV * D,
|
||||
BLOCK_SIZE=BLOCK_SIZE)
|
||||
|
||||
return k_values, v_values
|
||||
|
||||
def flash_attn_varlen_func_impl(
|
||||
q: torch.Tensor,
|
||||
k_cache: torch.Tensor,
|
||||
v_cache: torch.Tensor,
|
||||
out: torch.Tensor,
|
||||
cu_seqlens_q: torch.Tensor,
|
||||
cu_seqlens_k: torch.Tensor,
|
||||
total_tokens: int,
|
||||
max_seqlen_q: int,
|
||||
max_seqlen_k: int,
|
||||
softmax_scale: float,
|
||||
window_size: Optional[list[int]], # -1 means infinite context window
|
||||
alibi_slopes: Optional[list[float]],
|
||||
block_table: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
k, v = vllm_layout_trans(cu_seqlens_q, cu_seqlens_k, block_table,
|
||||
k_cache, v_cache, max_seqlen_k, total_tokens)
|
||||
output = aiter.flash_attn_varlen_func(
|
||||
q=q,
|
||||
k=k,
|
||||
v=v,
|
||||
cu_seqlens_q=cu_seqlens_q,
|
||||
max_seqlen_q=max_seqlen_q,
|
||||
min_seqlen_q=1,
|
||||
cu_seqlens_k=cu_seqlens_k,
|
||||
max_seqlen_k=max_seqlen_k,
|
||||
softmax_scale=softmax_scale,
|
||||
causal=True,
|
||||
alibi_slopes=alibi_slopes,
|
||||
window_size=window_size,
|
||||
out=out,
|
||||
)
|
||||
return output
|
||||
|
||||
def flash_attn_varlen_func_fake(
|
||||
q: torch.Tensor,
|
||||
k_cache: torch.Tensor,
|
||||
v_cache: torch.Tensor,
|
||||
out: torch.Tensor,
|
||||
cu_seqlens_q: torch.Tensor,
|
||||
cu_seqlens_k: torch.Tensor,
|
||||
total_tokens: int,
|
||||
max_seqlen_q: int,
|
||||
max_seqlen_k: int,
|
||||
softmax_scale: float,
|
||||
window_size: Optional[list[int]], # -1 means infinite context window
|
||||
alibi_slopes: Optional[list[float]],
|
||||
block_table: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
return torch.empty(q.shape[0],
|
||||
q.shape[1],
|
||||
v_cache.shape[-2],
|
||||
dtype=torch.float8_e4m3fnuz,
|
||||
device="cuda")
|
||||
|
||||
direct_register_custom_op("flash_attn_varlen_func",
|
||||
flash_attn_varlen_func_impl, ["out"],
|
||||
flash_attn_varlen_func_fake,
|
||||
dispatch_key=current_platform.dispatch_key)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class AiterFlashAttentionMetadataBuilder:
|
||||
|
||||
def __init__(self, runner: "GPUModelRunner", kv_cache_spec: AttentionSpec,
|
||||
block_table: BlockTable):
|
||||
model_config = runner.model_config
|
||||
|
||||
self.runner = runner
|
||||
self.num_heads_q = model_config.get_num_attention_heads(
|
||||
runner.parallel_config)
|
||||
self.num_heads_kv = model_config.get_num_kv_heads(
|
||||
runner.parallel_config)
|
||||
self.headdim = model_config.get_head_size()
|
||||
self.block_size = kv_cache_spec.block_size
|
||||
self.kv_cache_spec = kv_cache_spec
|
||||
self.block_table = block_table
|
||||
|
||||
# Sliding window size to be used with the AOT scheduler will be
|
||||
# populated on first build() call.
|
||||
self.aot_sliding_window: Optional[tuple[int, int]] = None
|
||||
|
||||
def reorder_batch(self, input_batch: "InputBatch",
|
||||
scheduler_output: "SchedulerOutput") -> bool:
|
||||
return False
|
||||
|
||||
def build(self, common_prefix_len: int,
|
||||
common_attn_metadata: CommonAttentionMetadata):
|
||||
|
||||
num_reqs = common_attn_metadata.num_reqs
|
||||
num_actual_tokens = common_attn_metadata.num_actual_tokens
|
||||
max_query_len = common_attn_metadata.max_query_len
|
||||
|
||||
max_seq_len = int(self.runner.seq_lens_np[:num_reqs].max())
|
||||
total_tokens = int(self.runner.seq_lens_np[:num_reqs].sum())
|
||||
query_start_loc = common_attn_metadata.query_start_loc
|
||||
seq_lens = common_attn_metadata.seq_lens
|
||||
block_table = self.block_table
|
||||
block_table_tensor = block_table.get_device_tensor()[:num_reqs]
|
||||
|
||||
block_table.slot_mapping[:num_actual_tokens].copy_(
|
||||
block_table.slot_mapping_cpu[:num_actual_tokens],
|
||||
non_blocking=True)
|
||||
# Fill unused with -1. Needed for reshape_and_cache in full cuda graph
|
||||
# mode.
|
||||
block_table.slot_mapping[num_actual_tokens:].fill_(-1)
|
||||
|
||||
slot_mapping = block_table.slot_mapping[:num_actual_tokens]
|
||||
|
||||
cu_seq_lens = torch.zeros(seq_lens.shape[0] + 1,
|
||||
dtype=torch.int32,
|
||||
device="cuda")
|
||||
torch.cumsum(seq_lens,
|
||||
dim=0,
|
||||
dtype=cu_seq_lens.dtype,
|
||||
out=cu_seq_lens[1:])
|
||||
|
||||
def schedule(batch_size, cu_query_lens, max_query_len, seqlens,
|
||||
max_seq_len, causal):
|
||||
return None
|
||||
|
||||
# for local attention
|
||||
local_attn_metadata = None
|
||||
if self.runner.attention_chunk_size is not None:
|
||||
seqlens_q_local_np, virt_q_cu_seqlens_np, virt_k_seqlens_np, \
|
||||
virt_block_table_tensor = make_local_attention_virtual_batches(
|
||||
self.runner.attention_chunk_size,
|
||||
self.runner.query_start_loc_np[:num_reqs + 1],
|
||||
self.runner.seq_lens_np[:num_reqs],
|
||||
block_table_tensor,
|
||||
self.block_size,
|
||||
)
|
||||
local_query_start_loc = torch.from_numpy(virt_q_cu_seqlens_np).to(
|
||||
self.runner.device, non_blocking=True)
|
||||
local_seqused_k = torch.from_numpy(virt_k_seqlens_np).to(
|
||||
self.runner.device, non_blocking=True)
|
||||
local_max_query_len = int(seqlens_q_local_np.max())
|
||||
local_max_seq_len = int(virt_k_seqlens_np.max())
|
||||
local_scheduler_metadata = schedule(
|
||||
batch_size=local_query_start_loc.shape[0] - 1,
|
||||
cu_query_lens=local_query_start_loc,
|
||||
max_query_len=local_max_query_len,
|
||||
seqlens=local_seqused_k,
|
||||
max_seq_len=local_max_seq_len,
|
||||
causal=True)
|
||||
|
||||
local_cu_seq_lens = torch.zeros(virt_k_seqlens_np.shape[0] + 1,
|
||||
dtype=torch.int32,
|
||||
device=self.runner.device)
|
||||
local_cu_seq_lens[1:] = torch.cumsum(
|
||||
torch.from_numpy(virt_k_seqlens_np).to(
|
||||
device=self.runner.device,
|
||||
dtype=torch.int32,
|
||||
non_blocking=True),
|
||||
dim=0)
|
||||
|
||||
|
||||
local_attn_metadata = \
|
||||
AiterFlashAttentionMetadata.LocalAttentionMetadata(
|
||||
local_query_start_loc=local_query_start_loc,
|
||||
local_seqused_k=local_seqused_k,
|
||||
local_block_table=virt_block_table_tensor,
|
||||
local_max_query_len=local_max_query_len,
|
||||
local_max_seq_len=local_max_seq_len,
|
||||
local_cu_seq_lens=local_cu_seq_lens,
|
||||
local_scheduler_metadata=local_scheduler_metadata,
|
||||
)
|
||||
|
||||
use_cascade = common_prefix_len > 0
|
||||
|
||||
cu_prefix_query_lens = None
|
||||
prefix_kv_lens = None
|
||||
suffix_kv_lens = None
|
||||
|
||||
attn_metadata = AiterFlashAttentionMetadata(
|
||||
num_actual_tokens=num_actual_tokens,
|
||||
max_query_len=max_query_len,
|
||||
query_start_loc=query_start_loc,
|
||||
max_seq_len=max_seq_len,
|
||||
seq_lens=seq_lens,
|
||||
cu_seq_lens=cu_seq_lens,
|
||||
total_tokens=total_tokens,
|
||||
block_table=block_table_tensor,
|
||||
slot_mapping=slot_mapping,
|
||||
use_cascade=use_cascade,
|
||||
common_prefix_len=common_prefix_len,
|
||||
cu_prefix_query_lens=cu_prefix_query_lens,
|
||||
prefix_kv_lens=prefix_kv_lens,
|
||||
suffix_kv_lens=suffix_kv_lens,
|
||||
local_attn_metadata=local_attn_metadata,
|
||||
)
|
||||
return attn_metadata
|
||||
|
||||
def can_run_in_cudagraph(
|
||||
self, common_attn_metadata: CommonAttentionMetadata) -> bool:
|
||||
# Full CUDA Graph always supported (FA2 support checked separately)
|
||||
return True
|
||||
|
||||
def use_cascade_attention(self, *args, **kwargs) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
class AiterFlashAttentionBackend(AttentionBackend):
|
||||
|
||||
accept_output_buffer: bool = True
|
||||
|
||||
@classmethod
|
||||
def get_supported_head_sizes(cls) -> list[int]:
|
||||
return [32, 64, 96, 128, 160, 192, 224, 256]
|
||||
|
||||
@classmethod
|
||||
def validate_head_size(cls, head_size: int) -> None:
|
||||
supported_head_sizes = cls.get_supported_head_sizes()
|
||||
if head_size not in supported_head_sizes:
|
||||
attn_type = cls.__name__.removesuffix("Backend")
|
||||
raise ValueError(
|
||||
f"Head size {head_size} is not supported by {attn_type}. "
|
||||
f"Supported head sizes are: {supported_head_sizes}. "
|
||||
"Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use "
|
||||
"FlexAttention backend which supports all head sizes.")
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "FLASH_ATTN_VLLM_V1"
|
||||
|
||||
@staticmethod
|
||||
def get_impl_cls() -> type["AiterFlashAttentionImpl"]:
|
||||
return AiterFlashAttentionImpl
|
||||
|
||||
@staticmethod
|
||||
def get_metadata_cls() -> type["AttentionMetadata"]:
|
||||
return AiterFlashAttentionMetadata
|
||||
|
||||
@staticmethod
|
||||
def get_builder_cls() -> type["AiterFlashAttentionMetadataBuilder"]:
|
||||
return AiterFlashAttentionMetadataBuilder
|
||||
|
||||
@staticmethod
|
||||
def get_kv_cache_shape(
|
||||
num_blocks: int,
|
||||
block_size: int,
|
||||
num_kv_heads: int,
|
||||
head_size: int,
|
||||
) -> tuple[int, ...]:
|
||||
if block_size % 16 != 0:
|
||||
raise ValueError("Block size must be a multiple of 16.")
|
||||
return (2, num_blocks, block_size, num_kv_heads, head_size)
|
||||
|
||||
|
||||
@dataclass
|
||||
class AiterFlashAttentionMetadata:
|
||||
# NOTE(sang): Definition of context_len, query_len, and seq_len.
|
||||
# |---------- N-1 iteration --------|
|
||||
# |---------------- N iteration ---------------------|
|
||||
# |- tokenA -|......................|-- newTokens ---|
|
||||
# |---------- context_len ----------|
|
||||
# |-------------------- seq_len ---------------------|
|
||||
# |-- query_len ---|
|
||||
|
||||
num_actual_tokens: int # Number of tokens excluding padding.
|
||||
max_query_len: int
|
||||
query_start_loc: torch.Tensor
|
||||
max_seq_len: int
|
||||
seq_lens: torch.Tensor
|
||||
cu_seq_lens: torch.Tensor
|
||||
total_tokens: int
|
||||
block_table: torch.Tensor
|
||||
slot_mapping: torch.Tensor
|
||||
|
||||
# For cascade attention.
|
||||
use_cascade: bool
|
||||
common_prefix_len: int
|
||||
cu_prefix_query_lens: Optional[torch.Tensor]
|
||||
prefix_kv_lens: Optional[torch.Tensor]
|
||||
suffix_kv_lens: Optional[torch.Tensor]
|
||||
|
||||
# for local attention
|
||||
@dataclass
|
||||
class LocalAttentionMetadata:
|
||||
local_query_start_loc: torch.Tensor
|
||||
local_seqused_k: torch.Tensor
|
||||
local_block_table: torch.Tensor
|
||||
local_max_query_len: int
|
||||
local_max_seq_len: int
|
||||
local_cu_seq_lens: torch.Tensor
|
||||
local_scheduler_metadata: Optional[torch.Tensor]
|
||||
|
||||
local_attn_metadata: Optional[LocalAttentionMetadata] = None


class AiterFlashAttentionImpl(AttentionImpl):

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: Optional[list[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
        blocksparse_params: Optional[dict[str, Any]] = None,
        logits_soft_cap: Optional[float] = None,
        attn_type: AttentionType = AttentionType.DECODER,
        kv_sharing_target_layer_name: Optional[str] = None,
        use_irope: bool = False,
    ) -> None:
        if blocksparse_params is not None:
            raise ValueError(
                "AiterFlashAttention does not support block-sparse attention.")
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_kv_heads
        if alibi_slopes is not None:
            alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
        self.alibi_slopes = alibi_slopes
        if sliding_window is None:
            self.sliding_window = [-1, -1]
        else:
            self.sliding_window = [sliding_window - 1, 0]
        self.kv_cache_dtype = kv_cache_dtype
        if logits_soft_cap is None:
            # In flash-attn, setting logits_soft_cap as 0 means no soft cap.
            logits_soft_cap = 0.
        self.logits_soft_cap = logits_soft_cap
        self.kv_sharing_target_layer_name = kv_sharing_target_layer_name

        assert self.num_heads % self.num_kv_heads == 0
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        AiterFlashAttentionBackend.validate_head_size(head_size)

        if attn_type != AttentionType.DECODER:
            raise NotImplementedError("Encoder self-attention and "
                                      "encoder/decoder cross-attention "
                                      "are not implemented for "
                                      "AiterFlashAttentionImpl")
        self.use_irope = use_irope
        if is_quantized_kv_cache(self.kv_cache_dtype):
            raise NotImplementedError(
                "AiterFlashAttention does not support fp8 kv-cache on this "
                "device.")

    def forward(
        self,
        layer: torch.nn.Module,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AiterFlashAttentionMetadata,
        output: Optional[torch.Tensor] = None,
        output_scale: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Forward pass with AiterFlashAttention.

        Args:
            query: shape = [num_tokens, num_heads, head_size]
            key: shape = [num_tokens, num_kv_heads, head_size]
            value: shape = [num_tokens, num_kv_heads, head_size]
            kv_cache = [2, num_blocks, block_size, num_kv_heads, head_size]
            attn_metadata: Metadata for attention.
        Returns:
            shape = [num_tokens, num_heads * head_size]
        NOTE: For FP8 quantization, flash-attn expects the size of
            {q,k,v}_descale to be (num_sequences, num_kv_heads).
            We use torch's .expand() to avoid duplicating values.
        """
        assert output is not None, "Output tensor must be provided."

        if output_scale is not None:
            raise NotImplementedError(
                "fused output quantization is not yet supported"
                " for AiterFlashAttentionImpl")

        if attn_metadata is None:
            # Profiling run.
            return output

        # IMPORTANT!
        # NOTE(woosuk): With piece-wise CUDA graphs, this method is executed in
        # eager-mode PyTorch. Thus, we need to be careful about any CPU overhead
        # in this method. For example, `view` and `slice` (or `[:n]`) operations
        # are surprisingly slow even in the case they do not invoke any GPU ops.
        # Minimize the PyTorch ops in this method as much as possible.
        # Whenever making a change in this method, please benchmark the
        # performance to make sure it does not introduce any overhead.

        num_actual_tokens = attn_metadata.num_actual_tokens
        key_cache, value_cache = kv_cache.unbind(0)
        if self.kv_sharing_target_layer_name is None:
            # Reshape the input keys and values and store them in the cache.
            # Skip this if sharing KV cache with an earlier attention layer.
            # NOTE(woosuk): Here, key and value are padded while slot_mapping is
            # not padded. However, we don't need to do key[:num_actual_tokens]
            # and value[:num_actual_tokens] because the reshape_and_cache_flash
            # op uses the slot_mapping's shape to determine the number of
            # actual tokens.
            torch.ops._C_cache_ops.reshape_and_cache_flash(
                key,
                value,
                key_cache,
                value_cache,
                attn_metadata.slot_mapping,
                self.kv_cache_dtype,
                layer._k_scale,
                layer._v_scale,
            )

        if self.kv_cache_dtype.startswith("fp8"):
            key_cache = key_cache.view(torch.float8_e4m3fnuz)
            value_cache = value_cache.view(torch.float8_e4m3fnuz)
            num_tokens, num_heads, head_size = query.shape
            query, _ = ops.scaled_fp8_quant(
                query.reshape(
                    (num_tokens, num_heads * head_size)).contiguous(),
                layer._q_scale)
            query = query.reshape((num_tokens, num_heads, head_size))

        # Compute attention and update output up to `num_actual_tokens`.
        use_local_attn = \
            (self.use_irope and attn_metadata.local_attn_metadata is not None)

        if not attn_metadata.use_cascade or use_local_attn:
            if use_local_attn:
                assert attn_metadata.local_attn_metadata is not None
                local_metadata = attn_metadata.local_attn_metadata
                cu_seqlens_q = local_metadata.local_query_start_loc
                seqused_k = local_metadata.local_seqused_k
                max_seqlen_q = local_metadata.local_max_query_len
                max_seqlen_k = local_metadata.local_max_seq_len
                block_table = local_metadata.local_block_table
            else:
                cu_seqlens_q = attn_metadata.query_start_loc
                seqused_k = attn_metadata.seq_lens
                max_seqlen_q = attn_metadata.max_query_len
                max_seqlen_k = attn_metadata.max_seq_len
                block_table = attn_metadata.block_table

            if max_seqlen_q > 1:
                cu_seq_lens = attn_metadata.cu_seq_lens
                total_tokens = attn_metadata.total_tokens
                torch.ops.vllm.flash_attn_varlen_func(
                    query[:num_actual_tokens],
                    key_cache,
                    value_cache,
                    out=output[:num_actual_tokens],
                    cu_seqlens_q=cu_seqlens_q,
                    max_seqlen_q=max_seqlen_q,
                    max_seqlen_k=max_seqlen_k,
                    total_tokens=total_tokens,
                    softmax_scale=self.scale,
                    alibi_slopes=self.alibi_slopes,
                    window_size=self.sliding_window,
                    block_table=block_table,
                    cu_seqlens_k=(cu_seq_lens if not use_local_attn else
                                  local_metadata.local_cu_seq_lens),
                )

            _, num_heads, head_size = query.shape
            _PARTITION_SIZE_ROCM = 256
            num_seqs = seqused_k.shape[0]
            nbytes_per_qo_elem = torch.finfo(output.dtype).bits // 8
            max_num_partitions = (max_seqlen_k + _PARTITION_SIZE_ROCM -
                                  1) // _PARTITION_SIZE_ROCM

            workspace_buffer = torch.empty(
                (num_seqs * num_heads * max_num_partitions * head_size) *
                nbytes_per_qo_elem + 2 *
                (num_seqs * num_heads * max_num_partitions) * 4,
                dtype=torch.uint8,
                device=output.device,
            )

            aiter.paged_attention_v1(
                output[:num_actual_tokens],
                workspace_buffer,
                query[:num_actual_tokens],
                key_cache,
                value_cache,
                self.scale,
                block_table,
                cu_seqlens_q,
                seqused_k,
                max_seqlen_k,
                self.alibi_slopes,
                self.kv_cache_dtype,
                "NHD",
                self.logits_soft_cap,
                layer._k_scale,
                layer._v_scale,
                None,
                _PARTITION_SIZE_ROCM,
            )
            return output
        else:
            raise NotImplementedError(
                "Cascade attention is not implemented for ROCM AITER")
449
vllm/v1/attention/backends/triton_attn.py
Normal file
@@ -0,0 +1,449 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Attention layer with PagedAttention and Triton prefix prefill."""
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, ClassVar, Optional

import torch

from vllm import _custom_ops as ops
from vllm import envs
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                              AttentionMetadata, AttentionType)
from vllm.attention.ops.chunked_prefill_paged_decode import (
    chunked_prefill_paged_decode)
from vllm.attention.ops.paged_attn import PagedAttention
from vllm.attention.ops.triton_unified_attention import unified_attention
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
from vllm.v1.attention.backends.utils import (
    AttentionMetadataBuilder, CommonAttentionMetadata,
    make_local_attention_virtual_batches)
from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.v1.worker.block_table import BlockTable

if TYPE_CHECKING:
    from vllm.v1.worker.gpu_model_runner import GPUModelRunner

logger = init_logger(__name__)


@dataclass
class TritonAttentionMetadata:
    # NOTE(sang): Definition of context_len, query_len, and seq_len.
    #   |---------- N-1 iteration --------|
    #   |---------------- N iteration ---------------------|
    #   |- tokenA -|......................|-- newTokens ---|
    #   |---------- context_len ----------|
    #   |-------------------- seq_len ---------------------|
    #                                   |-- query_len ---|

    num_actual_tokens: int  # Number of tokens excluding padding.
    max_query_len: int
    query_start_loc: torch.Tensor
    max_seq_len: int
    seq_lens: torch.Tensor
    block_table: torch.Tensor
    slot_mapping: torch.Tensor

    # For cascade attention.
    use_cascade: bool
    common_prefix_len: int
    cu_prefix_query_lens: Optional[torch.Tensor]
    prefix_kv_lens: Optional[torch.Tensor]
    suffix_kv_lens: Optional[torch.Tensor]

    # Optional aot scheduling
    scheduler_metadata: Optional[torch.Tensor] = None
    prefix_scheduler_metadata: Optional[torch.Tensor] = None

    # for local attention
    @dataclass
    class LocalAttentionMetadata:
        local_query_start_loc: torch.Tensor
        local_seqused_k: torch.Tensor
        local_block_table: torch.Tensor
        local_max_query_len: int
        local_max_seq_len: int
        local_scheduler_metadata: Optional[torch.Tensor]

    local_attn_metadata: Optional[LocalAttentionMetadata] = None


class TritonAttentionMetadataBuilder(
        AttentionMetadataBuilder[TritonAttentionMetadata]):
    full_cudagraph_supported: ClassVar[bool] = True

    def __init__(self, runner: "GPUModelRunner", kv_cache_spec: AttentionSpec,
                 block_table: BlockTable):
        self.runner = runner
        self.block_size = kv_cache_spec.block_size
        self.kv_cache_spec = kv_cache_spec
        self.block_table = block_table

    def build_for_cudagraph_capture(
        self, common_attn_metadata: CommonAttentionMetadata
    ) -> TritonAttentionMetadata:
        attn_metadata = self.build(0, common_attn_metadata)
        # When doing full graph capture, setting seq_lens to
        # max_model_len will cause graph capture to be extremely
        # slow, so here we set it to 1.
        attn_metadata.seq_lens.fill_(1)
        return attn_metadata

    def build(
        self, common_prefix_len: int,
        common_attn_metadata: CommonAttentionMetadata
    ) -> TritonAttentionMetadata:
        num_reqs = common_attn_metadata.num_reqs
        num_actual_tokens = common_attn_metadata.num_actual_tokens
        max_query_len = common_attn_metadata.max_query_len

        max_seq_len = int(self.runner.seq_lens_np[:num_reqs].max())
        query_start_loc = common_attn_metadata.query_start_loc
        seq_lens = common_attn_metadata.seq_lens
        block_table = self.block_table
        block_table_tensor = block_table.get_device_tensor()[:num_reqs]

        block_table.slot_mapping[:num_actual_tokens].copy_(
            block_table.slot_mapping_cpu[:num_actual_tokens],
            non_blocking=True)
        # Fill unused with -1. Needed for reshape_and_cache in full cuda graph
        # mode.
        block_table.slot_mapping[num_actual_tokens:].fill_(-1)

        slot_mapping = block_table.slot_mapping[:num_actual_tokens]

        # for local attention
        local_attn_metadata = None
        if self.runner.attention_chunk_size is not None:
            seqlens_q_local_np, virt_q_cu_seqlens_np, virt_k_seqlens_np, \
                virt_block_table_tensor = make_local_attention_virtual_batches(
                    self.runner.attention_chunk_size,
                    self.runner.query_start_loc_np[:num_reqs + 1],
                    self.runner.seq_lens_np[:num_reqs],
                    block_table_tensor,
                    self.block_size,
                )
            local_query_start_loc = torch.from_numpy(virt_q_cu_seqlens_np).to(
                self.runner.device, non_blocking=True)
            local_seqused_k = torch.from_numpy(virt_k_seqlens_np).to(
                self.runner.device, non_blocking=True)
            local_max_query_len = seqlens_q_local_np.max()
            local_max_seq_len = virt_k_seqlens_np.max()

            local_attn_metadata = TritonAttentionMetadata \
                .LocalAttentionMetadata(
                    local_query_start_loc=local_query_start_loc,
                    local_seqused_k=local_seqused_k,
                    local_block_table=virt_block_table_tensor,
                    local_max_query_len=local_max_query_len,
                    local_max_seq_len=local_max_seq_len,
                    local_scheduler_metadata=None,
                )

        use_cascade = common_prefix_len > 0

        if use_cascade:
            cu_prefix_query_lens = torch.tensor([0, num_actual_tokens],
                                                dtype=torch.int32,
                                                device=self.runner.device)
            prefix_kv_lens = torch.tensor([common_prefix_len],
                                          dtype=torch.int32,
                                          device=self.runner.device)
            suffix_kv_lens = (self.runner.seq_lens_np[:num_reqs] -
                              common_prefix_len)
            suffix_kv_lens = torch.from_numpy(suffix_kv_lens).to(
                self.runner.device)
        else:
            cu_prefix_query_lens = None
            prefix_kv_lens = None
            suffix_kv_lens = None
            prefix_scheduler_metadata = None

        attn_metadata = TritonAttentionMetadata(
            num_actual_tokens=num_actual_tokens,
            max_query_len=max_query_len,
            query_start_loc=query_start_loc,
            max_seq_len=max_seq_len,
            seq_lens=seq_lens,
            block_table=block_table_tensor,
            slot_mapping=slot_mapping,
            use_cascade=use_cascade,
            common_prefix_len=common_prefix_len,
            cu_prefix_query_lens=cu_prefix_query_lens,
            prefix_kv_lens=prefix_kv_lens,
            suffix_kv_lens=suffix_kv_lens,
            local_attn_metadata=local_attn_metadata,
            prefix_scheduler_metadata=prefix_scheduler_metadata,
        )
        return attn_metadata

    def can_run_in_cudagraph(
            self, common_attn_metadata: CommonAttentionMetadata) -> bool:
        # Full CUDA Graph always supported
        return True


class TritonAttentionBackend(AttentionBackend):

    accept_output_buffer: bool = True

    @classmethod
    def get_supported_head_sizes(cls) -> list[int]:
        return [32, 64, 96, 128, 160, 192, 224, 256]

    @classmethod
    def validate_head_size(cls, head_size: int) -> None:
        supported_head_sizes = cls.get_supported_head_sizes()
        if head_size not in supported_head_sizes:
            attn_type = cls.__name__.removesuffix("Backend")
            raise ValueError(
                f"Head size {head_size} is not supported by {attn_type}. "
                f"Supported head sizes are: {supported_head_sizes}. "
                "Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use "
                "FlexAttention backend which supports all head sizes.")

    @staticmethod
    def get_name() -> str:
        return "TRITON_ATTN_VLLM_V1"

    @staticmethod
    def get_impl_cls() -> type["TritonAttentionImpl"]:
        return TritonAttentionImpl

    @staticmethod
    def get_metadata_cls() -> type["AttentionMetadata"]:
        return TritonAttentionMetadata

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> tuple[int, ...]:
        if block_size % 16 != 0:
            raise ValueError("Block size must be a multiple of 16.")
        return (2, num_blocks, block_size, num_kv_heads, head_size)

    @staticmethod
    def use_cascade_attention(*args, **kwargs) -> bool:
        return False

    @staticmethod
    def get_builder_cls() -> type["TritonAttentionMetadataBuilder"]:
        return TritonAttentionMetadataBuilder


class TritonAttentionImpl(AttentionImpl):

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: Optional[list[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
        blocksparse_params: Optional[dict[str, Any]] = None,
        logits_soft_cap: Optional[float] = None,
        attn_type: AttentionType = AttentionType.DECODER,
        kv_sharing_target_layer_name: Optional[str] = None,
        use_irope: bool = False,
    ) -> None:
        if blocksparse_params is not None:
            raise ValueError(
                "TritonAttention does not support block-sparse attention.")
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_kv_heads
        if alibi_slopes is not None:
            alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
        self.alibi_slopes = alibi_slopes
        if sliding_window is None:
            self.sliding_window = (-1, -1)
        else:
            self.sliding_window = (sliding_window - 1, 0)
        self.kv_cache_dtype = kv_cache_dtype
        if logits_soft_cap is None:
            # In flash-attn, setting logits_soft_cap as 0 means no soft cap.
            logits_soft_cap = 0
        self.logits_soft_cap = logits_soft_cap
        self.kv_sharing_target_layer_name = kv_sharing_target_layer_name

        self.use_irope = use_irope

        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        TritonAttentionBackend.validate_head_size(head_size)

        if attn_type != AttentionType.DECODER:
            raise NotImplementedError("Encoder self-attention and "
                                      "encoder/decoder cross-attention "
                                      "are not implemented for "
                                      "TritonAttentionImpl")

        self.fp8_dtype = current_platform.fp8_dtype()
        self.force_prefill_decode_attn = \
            envs.VLLM_V1_USE_PREFILL_DECODE_ATTENTION

    def forward(
        self,
        layer: torch.nn.Module,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: FlashAttentionMetadata,
        output: Optional[torch.Tensor] = None,
        output_scale: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Forward pass with Triton attention.

        Args:
            query: shape = [num_tokens, num_heads, head_size]
            key: shape = [num_tokens, num_kv_heads, head_size]
            value: shape = [num_tokens, num_kv_heads, head_size]
            kv_cache = [2, num_blocks, block_size, num_kv_heads, head_size]
            attn_metadata: Metadata for attention.
        Returns:
            shape = [num_tokens, num_heads * head_size]
        """
        assert output is not None, "Output tensor must be provided."

        if output_scale is not None:
            raise NotImplementedError(
                "fused output quantization is not yet supported"
                " for TritonAttentionImpl")

        if attn_metadata is None:
            # Profiling run.
            return output

        assert attn_metadata.use_cascade is False

        # IMPORTANT!
        # NOTE(woosuk): With piece-wise CUDA graphs, this method is executed in
        # eager-mode PyTorch. Thus, we need to be careful about any CPU overhead
        # in this method. For example, `view` and `slice` (or `[:n]`) operations
        # are surprisingly slow even in the case they do not invoke any GPU ops.
        # Minimize the PyTorch ops in this method as much as possible.
        # Whenever making a change in this method, please benchmark the
        # performance to make sure it does not introduce any overhead.

        use_prefill_decode_attn = self.force_prefill_decode_attn
        num_actual_tokens = attn_metadata.num_actual_tokens

        if use_prefill_decode_attn:
            key_cache, value_cache = PagedAttention.split_kv_cache(
                kv_cache, self.num_kv_heads, self.head_size)
        else:
            key_cache, value_cache = kv_cache.unbind(0)

        if self.kv_sharing_target_layer_name is None:
            # Reshape the input keys and values and store them in the cache.
            # Skip this if sharing KV cache with an earlier attention layer.
            if use_prefill_decode_attn:
                PagedAttention.write_to_paged_cache(
                    key,
                    value,
                    key_cache,
                    value_cache,
                    attn_metadata.slot_mapping,
                    self.kv_cache_dtype,
                    layer._k_scale,
                    layer._v_scale,
                )
            else:
                torch.ops._C_cache_ops.reshape_and_cache_flash(
                    key,
                    value,
                    key_cache,
                    value_cache,
                    attn_metadata.slot_mapping,
                    self.kv_cache_dtype,
                    layer._k_scale,
                    layer._v_scale,
                )

        if self.kv_cache_dtype.startswith("fp8"):
            key_cache = key_cache.view(self.fp8_dtype)
            value_cache = value_cache.view(self.fp8_dtype)
            num_tokens, num_heads, head_size = query.shape
            assert layer._q_scale == 1.0, \
                "A non 1.0 q_scale is not currently supported."
            if not current_platform.is_rocm():
                # Skip Q quantization on ROCm, since dequantizing back to
                # f32 in the attention kernel is not supported.
                query, _ = ops.scaled_fp8_quant(
                    query.reshape(
                        (num_tokens, num_heads * head_size)).contiguous(),
                    layer._q_scale)
                query = query.reshape((num_tokens, num_heads, head_size))

        use_local_attn = \
            (self.use_irope and attn_metadata.local_attn_metadata is not None)

        if use_local_attn:
            assert attn_metadata.local_attn_metadata is not None
            local_metadata = attn_metadata.local_attn_metadata
            cu_seqlens_q = local_metadata.local_query_start_loc
            seqused_k = local_metadata.local_seqused_k
            max_seqlen_q = local_metadata.local_max_query_len
            max_seqlen_k = local_metadata.local_max_seq_len
            block_table = local_metadata.local_block_table
        else:
            cu_seqlens_q = attn_metadata.query_start_loc
            seqused_k = attn_metadata.seq_lens
            max_seqlen_q = attn_metadata.max_query_len
            max_seqlen_k = attn_metadata.max_seq_len
            block_table = attn_metadata.block_table

        if use_prefill_decode_attn:
            # Compute attention and update output up to `num_actual_tokens`.
            chunked_prefill_paged_decode(query=query[:num_actual_tokens],
                                         key=key[:num_actual_tokens],
                                         value=value[:num_actual_tokens],
                                         output=output[:num_actual_tokens],
                                         kv_cache_dtype=self.kv_cache_dtype,
                                         key_cache=key_cache,
                                         value_cache=value_cache,
                                         block_table=block_table,
                                         query_start_loc=cu_seqlens_q,
                                         seq_lens=seqused_k,
                                         max_seq_len=max_seqlen_k,
                                         max_query_len=max_seqlen_q,
                                         k_scale=layer._k_scale,
                                         v_scale=layer._v_scale,
                                         alibi_slopes=self.alibi_slopes,
                                         sliding_window=self.sliding_window[0],
                                         sm_scale=self.scale)

        else:
            descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1])

            unified_attention(
                q=query[:num_actual_tokens],
                k=key_cache,
                v=value_cache,
                out=output[:num_actual_tokens],
                cu_seqlens_q=cu_seqlens_q,
                max_seqlen_q=max_seqlen_q,
                seqused_k=seqused_k,
                max_seqlen_k=max_seqlen_k,
                softmax_scale=self.scale,
                causal=True,
                alibi_slopes=self.alibi_slopes,
                window_size=self.sliding_window,
                block_table=block_table,
                softcap=self.logits_soft_cap,
                q_descale=None,  # Not supported
                k_descale=layer._k_scale.expand(descale_shape),
                v_descale=layer._v_scale.expand(descale_shape),
            )

        return output
314
vllm/v1/attention/backends/utils.py
Normal file
@@ -0,0 +1,314 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import abc
import functools
from abc import abstractmethod
from dataclasses import dataclass
from typing import TYPE_CHECKING, ClassVar, Generic, Optional, TypeVar

import numpy as np
import torch

from vllm.utils import cdiv

if TYPE_CHECKING:
    from vllm.v1.core.sched.output import SchedulerOutput
    from vllm.v1.worker.gpu_input_batch import InputBatch

import vllm.envs as envs
from vllm.distributed.kv_transfer.kv_connector.utils import (
    get_kv_connector_cache_layout)
from vllm.logger import init_logger

logger = init_logger(__name__)


@dataclass
class CommonAttentionMetadata:
    """
    Per-batch attention metadata, shared across layers and backends.
    AttentionMetadataBuilder instances use it to construct per-layer metadata.
    """

    query_start_loc: torch.Tensor
    """(batch_size + 1,), the start location of each request in query Tensor"""
    seq_lens: torch.Tensor
    """(batch_size,), the length of each request including both computed tokens
    and newly scheduled tokens"""
    num_reqs: int
    """Number of requests"""
    num_actual_tokens: int
    """Total number of tokens in batch"""
    max_query_len: int
    """Longest query in batch"""
    num_speculative_tokens: int = 0
    """Number of speculative tokens"""
    slot_mapping: Optional[torch.Tensor] = None
    """(batch_size, seq_len), slot mapping"""
    spec_layer_decoding: bool = False
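
# A minimal illustration (hypothetical values) of how a model runner might
# populate CommonAttentionMetadata for a batch of two requests that schedule
# 3 and 5 new tokens respectively (the second request is a fresh prefill):
#
#   common = CommonAttentionMetadata(
#       query_start_loc=torch.tensor([0, 3, 8], dtype=torch.int32),
#       seq_lens=torch.tensor([10, 5], dtype=torch.int32),
#       num_reqs=2,
#       num_actual_tokens=8,
#       max_query_len=5,
#   )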


M = TypeVar("M")


class AttentionMetadataBuilder(abc.ABC, Generic[M]):
    # Does this backend/builder support CUDA Graphs for attention.
    full_cudagraph_supported: ClassVar[bool] = False

    @abstractmethod
    def build(self, common_prefix_len: int,
              common_attn_metadata: CommonAttentionMetadata) -> M:
        """
        Central method that builds attention metadata.
        Some builders (MLA) require reorder_batch to be called prior to build.
        """
        raise NotImplementedError

    def can_run_in_cudagraph(
            self, common_attn_metadata: CommonAttentionMetadata) -> bool:
        """
        Can this batch (with given metadata) use CUDA Graphs for attention.
        """
        return False

    def build_for_cudagraph_capture(
            self, common_attn_metadata: CommonAttentionMetadata) -> M:
        """
        Build attention metadata for CUDA graph capture. Uses build by default.
        Subclasses that override this method should call self.build or
        super().build_for_cudagraph_capture.
        """
        return self.build(common_prefix_len=0,
                          common_attn_metadata=common_attn_metadata)

    def use_cascade_attention(
        self,
        common_prefix_len: int,
        query_lens: np.ndarray,
        num_query_heads: int,
        num_kv_heads: int,
        use_alibi: bool,
        use_sliding_window: bool,
        num_sms: int,
    ) -> bool:
        return False

    def reorder_batch(self, input_batch: "InputBatch",
                      scheduler_output: "SchedulerOutput") -> bool:
        """
        This method can reorder the batch if desired by the backend.
        :return: Has the batch been reordered (default False).
        """
        return False


def validate_kv_sharing_target(current_layer_name, target_layer_name,
                               static_forward_context):
    error_msg = (f"Specified KV sharing target layer for {current_layer_name} "
                 f"is not valid: target layer {target_layer_name} ")

    if current_layer_name == target_layer_name:
        raise ValueError(error_msg +
                         "cannot be the same as the current layer.")

    if target_layer_name not in static_forward_context:
        from vllm.model_executor.models.utils import extract_layer_index

        # If target layer name is not in the static fwd context, it means either
        # a) the target layer does not come BEFORE the current layer, or
        # b) the target layer is not an Attention layer that exists in the model
        current_layer_idx = extract_layer_index(current_layer_name)
        target_layer_idx = extract_layer_index(target_layer_name)
        if current_layer_idx <= target_layer_idx:
            raise ValueError(error_msg + "must come before the current layer.")
        else:
            raise ValueError(error_msg +
                             "is not a valid Attention layer in the model.")

    # Currently KV sharing is only supported between layers of the same type
    target_layer_attn_type = static_forward_context[
        target_layer_name].attn_type
    expected = static_forward_context[current_layer_name].attn_type
    if target_layer_attn_type != expected:
        raise ValueError(
            error_msg +
            f"must be the same type as the current layer ({expected}).")


@functools.lru_cache
def get_kv_cache_layout():
    # Override with format specified by the user.
    cache_layout = envs.VLLM_KV_CACHE_LAYOUT
    if cache_layout is None:
        cache_layout = get_kv_connector_cache_layout()
    else:
        logger.info_once("`VLLM_KV_CACHE_LAYOUT` environment variable " \
            "detected. Setting KV cache layout to %s.", cache_layout)

    return cache_layout
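
# Note: the resolved layout string is memoised by functools.lru_cache, so a
# sketch like the following only does the environment/connector lookup once:
#
#   layout = get_kv_cache_layout()
#   assert layout == get_kv_cache_layout()  # second call is served from cache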


#
# Take in `query_start_loc_np` and `seq_lens_np` and break the sequences into
# local attention blocks, where each block is passed to the attention kernel
# as an independent local ("virtual") batch item.
#
# For example, if we are performing a chunked prefill on a batch of 3 sequences:
#   q_seqlens  = [4, 10, 5]
#   kv_seqlens = [6, 17, 9]
# Then normally for regular attention we would compute with an attention mask
#  for batch idx 0 (q_seqlens = 4, kv_seqlens = 6) like:
#   batch idx: 0  (q_seqlens = 4, kv_seqlens = 6)
#        k_toks >   0 1 2 3 4 5
#        q_toks v  _____________
#               0 | 1 1 1
#               1 | 1 1 1 1
#               2 | 1 1 1 1 1
#               3 | 1 1 1 1 1 1
#
# for local attention (with attn_chunk_size = 4) we would compute with an
#  attention mask like:
#   batch idx: 0  (q_seqlens = 4, kv_seqlens = 6, attn_chunk_size = 4)
#        k_toks >   0 1 2 3 4 5
#        q_toks v  _____________
#               0 | 1 1 1
#               1 | 1 1 1 1
#               2 |         1
#               3 |         1 1
#
# We can simulate this mask using standard flash-attention by breaking the
#  sequences into local ("virtual") batches, where each local batch item is a
#  local attention block, so in this case batch idx 0 would be broken up into:
#
#   local-batch idx: 0 (q_seqlens = 2, kv_seqlens = 4)  (batch 0)
#        k_toks >   0 1 2 3
#        q_toks v  _____________
#               0 | 1 1 1
#               1 | 1 1 1 1
#   local-batch idx: 1 (q_seqlens = 2, kv_seqlens = 2) (batch 0)
#        k_toks >   4 5
#        q_toks v  _____________
#               2 | 1
#               3 | 1 1
#
# e.g. if we have:
#   attn_chunk_size = 4
#   query_start_loc_np = [0, 4, 14, 19] (q_seqlens = [4, 10, 5])
# Then this function would return:
#                           __b0__  ______b1______  __b2__ < orig batch indices
#   q_seqlens_local    = [   2,  2,  1,  4,  4,  1,  4,  1]
#   cu_seqlens_q_local = [0,  2,  4,  5,  9, 13, 14, 18, 19]
#   seqlens_k_local    = [   4,  2,  4,  4,  4,  1,  4,  1]
#   block_table_local  : shape[local_virtual_batches, pages_per_local_batch]
def make_local_attention_virtual_batches(
    attn_chunk_size: int,
    query_start_loc_np: np.ndarray,
    seq_lens_np: np.ndarray,
    block_table: torch.Tensor,
    block_size: int = 0,
) -> tuple[np.ndarray, np.ndarray, np.ndarray, torch.Tensor]:
    q_seqlens = query_start_loc_np[1:] - query_start_loc_np[:-1]
    actual_batch_size = seq_lens_np.shape[0]

    # Handle the case where we start in the middle of a local attention block;
    #  we assume q_seqlens > 0 (for all elements). For each batch idx we compute
    #  the number of tokens that are not in the first local attention block and
    #  then we can simply use a cdiv for the rest.
    # For example if we have:
    #   attn_chunk_size = 4
    #   q_seqlens = [4, 10, 5]
    #   k_seqlens = [6, 17, 9]
    # Then we would get:
    #   new_tokens_in_first_block = [2, 1, 4]
    #   local_blocks = [2, 4, 2]
    q_tokens_in_first_block = np.minimum(
        attn_chunk_size - ((seq_lens_np - q_seqlens) % attn_chunk_size),
        q_seqlens).astype(np.int32)
    tokens_in_last_block = attn_chunk_size + (seq_lens_np % -attn_chunk_size)
    local_blocks = 1 + cdiv(q_seqlens - q_tokens_in_first_block,
                            attn_chunk_size)

    # Once we know the number of local blocks we can compute the request spans
    #  for each batch idx and figure out the number of "virtual" requests we
    #  have to make.
    # For the above example we would get:
    #   seqlens_q_local = [2, 2, 1, 4, 4, 1, 4, 1]
    #
    # First get a batched arange. (E.g., [2, 4, 2] -> [0, 1, 0, 1, 2, 3, 0, 1])
    #   (TODO: make a utility to share this code with _prepare_inputs)
    # arange step 1. [2, 4, 2] -> [2, 6, 8]
    cu_num_blocks = np.cumsum(local_blocks)
    virtual_batches = cu_num_blocks[-1]
    # arange step 2. [2, 6, 8] -> [0, 0, 2, 2, 2, 2, 6, 6]
    block_offsets = np.repeat(cu_num_blocks - local_blocks, local_blocks)
    # arange step 3. [0, 1, 0, 1, 2, 3, 0, 1]
    arange = np.arange(virtual_batches, dtype=np.int32) - block_offsets
    # also compute reverse arange (i.e. [1, 0, 3, 2, 1, 0, 1, 0])
    rarange = np.repeat(local_blocks, local_blocks) - arange - 1
    # Then we can compute the seqlens_q_local, handling the fact that the
    #  first and last blocks could be partial
    seqlens_q_local = \
        np.repeat(q_seqlens - q_tokens_in_first_block, local_blocks)
    # set the first block since this may be a partial block
    seqlens_q_local[arange == 0] = q_tokens_in_first_block
    # set the remaining blocks
    seqlens_q_local[arange > 0] = np.minimum(
        seqlens_q_local - attn_chunk_size * (arange - 1),
        attn_chunk_size)[arange > 0]

    # convert from q_seqlens to cu_seqlens_q
    cu_seqlens_q_local = np.pad(np.cumsum(seqlens_q_local), (1, 0))\
        .astype(np.int32)

    # compute the seqlens_k_local, basically a full local attention block for
    #  all but the last block in each batch
    # For our example this will be:
    #   seqlens_k_local = [4, 2, 4, 4, 4, 1, 4, 1]
    seqlens_k_local = np.full(cu_num_blocks[-1],
                              attn_chunk_size,
                              dtype=np.int32)
    seqlens_k_local[cu_num_blocks - 1] = tokens_in_last_block

    k_seqstarts_absolute = np.repeat(seq_lens_np, local_blocks) - \
        (rarange * attn_chunk_size + \
            np.repeat(tokens_in_last_block, local_blocks))
    # For the example the local attention blocks start at:
    #                           _b0_  _____b1_____  _b2_
    #   k_seqstarts_absolute = [0, 4, 4, 8, 12, 16, 4, 8]
    block_starts = k_seqstarts_absolute // block_size
    assert attn_chunk_size % block_size == 0, \
        f"attn_chunk_size {attn_chunk_size} is not " \
        f"divisible by block_size {block_size}"
    pages_per_local_batch = attn_chunk_size // block_size

    # Create a block_table for the local attention blocks
    # For our example if we have a block-table like (assuming block_size=2):
    #   block_table = [
    #     [ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9], < batch 0
    #     [10, 11, 12, 13, 14, 15, 16, 17, 18, 19], < batch 1
    #     [20, 21, 22, 23, 24, 25, 26, 27, 28, 29], < batch 2
    #   ]
    # Then for the local batches we would want a block-table like
    #   block_table_local = [
    #     [  0,  1 ], < local-batch 0, (batch 0, starting from k[0])
    #     [  2,  3 ], < local-batch 1, (batch 0, starting from k[4])
    #     [ 12, 13 ], < local-batch 2, (batch 1, starting from k[4])
    #     [ 14, 15 ], < local-batch 3, (batch 1, starting from k[8])
    #     [ 16, 17 ], < local-batch 4, (batch 1, starting from k[12])
    #     [ 18, 19 ], < local-batch 5, (batch 1, starting from k[16])
    #     [ 22, 23 ], < local-batch 6, (batch 2, starting from k[4])
    #     [ 24, 25 ], < local-batch 7, (batch 2, starting from k[8])
    #   ]
    block_indices = np.broadcast_to(
        np.arange(pages_per_local_batch, dtype=np.int32),
        (virtual_batches, pages_per_local_batch)) \
        + np.expand_dims(block_starts, axis=1)
    block_indices = block_indices.flatten().clip(max=block_table.shape[1] - 1)
    batch_indices = np.repeat(np.arange(actual_batch_size, dtype=np.int32),
                              local_blocks * pages_per_local_batch)
    block_table_local = block_table[batch_indices, block_indices]\
        .view(virtual_batches, -1)

    return seqlens_q_local, cu_seqlens_q_local, seqlens_k_local, \
        block_table_local
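
# Worked example (a sketch using the numbers from the comments above; not
# executed here, the values follow directly from the code):
#
#   attn_chunk_size = 4
#   query_start_loc = np.array([0, 4, 14, 19], dtype=np.int32)  # q_seqlens = [4, 10, 5]
#   seq_lens = np.array([6, 17, 9], dtype=np.int32)             # k_seqlens
#   block_table = torch.arange(30, dtype=torch.int32).view(3, 10)  # block_size = 2
#
#   q_local, cu_q_local, k_local, bt_local = make_local_attention_virtual_batches(
#       attn_chunk_size, query_start_loc, seq_lens, block_table, block_size=2)
#
#   # q_local    -> [2, 2, 1, 4, 4, 1, 4, 1]
#   # cu_q_local -> [0, 2, 4, 5, 9, 13, 14, 18, 19]
#   # k_local    -> [4, 2, 4, 4, 4, 1, 4, 1]
#   # bt_local   -> 8 local batches x 2 pages, rows [0, 1], [2, 3], [12, 13], ...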