[gpt-oss] Add gpt-oss bf16 support

This commit is contained in:
2025-08-13 21:25:57 +08:00
parent 5d2e7edf78
commit 17ea2ec6aa
1232 changed files with 777 additions and 36 deletions

View File

View File

View File

@@ -0,0 +1,167 @@
# SPDX-License-Identifier: Apache-2.0
import numpy as np
import torch
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.attention.backends.torch_sdpa import (TorchSDPABackendImpl,
TorchSDPAMetadata)
from vllm.attention.backends.utils import CommonAttentionState
from vllm.attention.ops.ipex_attn import PagedAttention
from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
CommonAttentionMetadata)
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.v1.worker.block_table import BlockTable
from vllm.v1.worker.cpu_model_runner import CPUModelRunner
from vllm.v1.worker.gpu_input_batch import InputBatch
class TorchSDPABackend:
accept_output_buffer: bool = False
@staticmethod
def get_name() -> str:
return "TORCH_SDPA_VLLM_V1"
@staticmethod
def get_impl_cls() -> type["TorchSDPABackendImpl"]:
return TorchSDPABackendImpl
@staticmethod
def get_metadata_cls() -> type["AttentionMetadata"]:
return TorchSDPAMetadata
@staticmethod
def get_state_cls() -> type["CommonAttentionState"]:
return CommonAttentionState
@staticmethod
def get_builder_cls() -> type["TorchSDPAMetadataBuilderV1"]:
return TorchSDPAMetadataBuilderV1
@staticmethod
def get_kv_cache_shape(
num_blocks: int,
block_size: int,
num_kv_heads: int,
head_size: int,
) -> tuple[int, ...]:
return PagedAttention.get_kv_cache_shape(num_blocks, block_size,
num_kv_heads, head_size)
@staticmethod
def use_cascade_attention(*args, **kwargs) -> bool:
return False
class TorchSDPAMetadataBuilderV1(AttentionMetadataBuilder[TorchSDPAMetadata]):
def __init__(self, runner: CPUModelRunner, kv_cache_spec: AttentionSpec,
block_table: BlockTable) -> None:
self.runner = runner
self.block_table = block_table
# For reorder
self.reorder_prompt_req_index_list = np.empty(self.runner.max_num_reqs,
dtype=np.int64)
self.reorder_decode_req_index_list = np.empty(self.runner.max_num_reqs,
dtype=np.int64)
self.num_prompt_req: int = 0
self.seq_start_loc_cpu = torch.zeros(
runner.max_num_reqs + 1,
dtype=torch.int32,
device="cpu",
)
self.seq_start_loc_np = self.seq_start_loc_cpu.numpy()
def reorder_batch(self, input_batch: InputBatch,
scheduler_output: SchedulerOutput) -> bool:
prompt_list_idx = 0
decode_list_idx = 0
for req_index in range(input_batch.num_reqs):
if input_batch.num_computed_tokens_cpu[
req_index] < input_batch.num_prompt_tokens[req_index]:
# prompt stage
self.reorder_prompt_req_index_list[prompt_list_idx] = req_index
prompt_list_idx += 1
else:
# decode stage
self.reorder_decode_req_index_list[decode_list_idx] = req_index
decode_list_idx += 1
assert decode_list_idx + prompt_list_idx == input_batch.num_reqs
# Update prompt requests number
self.num_prompt_req = prompt_list_idx
reorder_req_num = 0
for req_index in range(decode_list_idx):
if self.reorder_decode_req_index_list[req_index] < prompt_list_idx:
reorder_req_num += 1
else:
break
if reorder_req_num == 0:
return False
reorder_prompt_list = (
self.reorder_prompt_req_index_list[:prompt_list_idx]
[-reorder_req_num:])
reorder_decode_list = (
self.reorder_decode_req_index_list[:decode_list_idx]
[:reorder_req_num])
assert reorder_decode_list.size == reorder_prompt_list.size
for idx in range(reorder_req_num):
prompt_req_index = reorder_prompt_list[idx].item()
decode_req_index = reorder_decode_list[idx].item()
input_batch.swap_states(prompt_req_index, decode_req_index)
return True
def build(self, common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata):
num_reqs = common_attn_metadata.num_reqs
num_actual_tokens = common_attn_metadata.num_actual_tokens
max_query_len = common_attn_metadata.max_query_len
runner = self.runner
block_table = self.block_table
seq_lens_np = runner.seq_lens_np[:num_reqs]
num_prompt_req = self.num_prompt_req
max_prefill_seq_len = seq_lens_np[:num_prompt_req].max().item(
) if num_prompt_req > 0 else 0
max_decode_seq_len = seq_lens_np[num_prompt_req:num_reqs].max().item(
) if num_prompt_req < num_reqs else 0
self.seq_start_loc_np[0] = 0
np.cumsum(seq_lens_np, out=self.seq_start_loc_np[1:num_reqs + 1])
num_prefill_tokens = runner.query_start_loc_np[num_prompt_req].item()
num_decode_tokens = runner.query_start_loc_np[num_reqs].item(
) - num_prefill_tokens
slot_mapping = block_table.slot_mapping_cpu[:num_actual_tokens].long()
block_table_tensor = block_table.get_device_tensor()
attn_metadata = TorchSDPAMetadata(
num_prefills=num_prompt_req,
num_prefill_tokens=num_prefill_tokens,
num_decode_tokens=num_decode_tokens,
slot_mapping=slot_mapping,
seq_lens_tensor=runner.
seq_lens_cpu[num_prompt_req:num_reqs], # decode
max_decode_seq_len=max_decode_seq_len, # decode
block_tables=block_table_tensor[num_prompt_req:num_reqs], # decode
chunked_prefill=True,
max_query_len=max_query_len,
max_kv_len=max_prefill_seq_len,
prefill_query_start_loc=runner.
query_start_loc_cpu[:num_prompt_req + 1], # prefill
kv_start_loc=self.seq_start_loc_cpu[:num_prompt_req +
1], # prefill
prefill_block_tables=block_table_tensor[:
num_prompt_req], # prefill
query_start_loc=runner.query_start_loc_cpu[:num_reqs +
1], # for logits index
multi_modal_placeholder_index_maps=None,
enable_kv_scales_calculation=False,
)
return attn_metadata

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,657 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Attention layer with FlashInfer."""
from __future__ import annotations
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Optional
import torch
from flashinfer import (BatchDecodeWithPagedKVCacheWrapper,
BatchPrefillWithPagedKVCacheWrapper)
#MultiLevelCascadeAttentionWrapper
import vllm.envs as envs
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionType)
from vllm.attention.layer import Attention
from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.logger import init_logger
from vllm.v1.attention.backends.flash_attn import use_cascade_attention
from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
CommonAttentionMetadata)
from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.v1.worker.block_table import BlockTable
if TYPE_CHECKING:
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.worker.gpu_input_batch import InputBatch
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024
logger = init_logger(__name__)
class FlashInferBackend(AttentionBackend):
accept_output_buffer: bool = True
@staticmethod
def get_supported_head_sizes() -> list[int]:
return [64, 128, 256]
@staticmethod
def get_name() -> str:
return "FLASHINFER_VLLM_V1"
@staticmethod
def get_impl_cls() -> type[FlashInferImpl]:
return FlashInferImpl
@staticmethod
def get_metadata_cls() -> type[FlashInferMetadata]:
return FlashInferMetadata
@staticmethod
def get_builder_cls() -> type[FlashInferMetadataBuilder]:
return FlashInferMetadataBuilder
@staticmethod
def get_kv_cache_shape(
num_blocks: int,
block_size: int,
num_kv_heads: int,
head_size: int,
) -> tuple[int, ...]:
return (num_blocks, 2, block_size, num_kv_heads, head_size)
@dataclass
class PerLayerParameters:
"""
Currently, FlashInfer backend only support models in which all layers share
the same values for the following hyperparameters.
"""
window_left: int
logits_soft_cap: Optional[float]
sm_scale: float
def get_per_layer_parameters(
vllm_config: VllmConfig) -> dict[str, PerLayerParameters]:
"""
Scan all attention layers and determine some hyperparameters
to use during `plan`.
"""
layers = get_layers_from_vllm_config(vllm_config, Attention)
per_layer_params: dict[str, PerLayerParameters] = {}
for key, layer in layers.items():
impl = layer.impl
assert isinstance(impl, FlashInferImpl)
# Infer hyperparameters from the attention layer
window_size = impl.sliding_window
window_left = window_size[0] if window_size is not None else -1
logits_soft_cap = impl.logits_soft_cap
sm_scale = impl.scale
per_layer_params[key] = PerLayerParameters(window_left,
logits_soft_cap, sm_scale)
return per_layer_params
def infer_global_hyperparameters(
per_layer_params: dict[str, PerLayerParameters]) -> PerLayerParameters:
"""
Currently, FlashInfer backend only support models in which all layers share
the same values for the following hyperparameters:
- `window_left`
- `logits_soft_cap`
- `sm_scale`
So this function asserts that all layers share the same values for these
hyperparameters and returns the global values.
"""
assert len(per_layer_params) > 0, "No attention layers found in the model."
param_sets = list(per_layer_params.values())
global_params = param_sets[0]
for params in param_sets:
assert params == global_params, (
"FlashInfer backend currently only supports models in which all "
"layers share the same values for the following hyperparameters: "
"`window_left`, `logits_soft_cap`, `sm_scale`.")
return global_params
@dataclass
class FlashInferMetadata:
num_actual_tokens: int # Number of tokens excluding padding.
# (batch_size + 1,). The cumulative subquery lengths of the sequences in
# the batch, used to index into subquery. E.g., if the subquery length
# is [4, 6], it is [0, 4, 10].
qo_indptr: torch.Tensor
# An example for paged_kv_indices, paged_kv_indptr:
# request 1, page indices [0, 5, 8]
# request 2, page indices [1, 6, 7]
# request 3, page indices [3, 4]
# paged_kv_indices is a concatenation of page indices of all requests:
# [0, 5, 8, 1, 6, 7, 3, 4]
# paged_kv_indptr is used to index into paged_kv_indices:
# [0, 3, 6, 8]
# The indptr of the paged kv cache, shape: [batch_size + 1]
paged_kv_indptr: torch.Tensor
# The page indices of the paged kv cache
paged_kv_indices: torch.Tensor
# The number of entries in the last page of each request in
# the paged kv cache, shape: [batch_size]
paged_kv_last_page_len: torch.Tensor
# The number of query/output heads
num_qo_heads: int
# The number of key/value heads
num_kv_heads: int
# The dimension of the attention heads
head_dim: int
# Block size of vllm
page_size: int
# The data type of the paged kv cache
data_type: torch.dtype
# The data type of the query
q_data_type: torch.dtype
slot_mapping: torch.Tensor
# For handling prefill decode split
num_decodes: int
num_decode_tokens: int
num_prefills: int
num_prefill_tokens: int
# For cascade attention.
use_cascade: bool
shared_qo_indptr: Optional[torch.Tensor] = None
shared_kv_page_indptr: Optional[torch.Tensor] = None
shared_kv_page_indices: Optional[torch.Tensor] = None
shared_kv_last_page_len: Optional[torch.Tensor] = None
prefill_wrapper: Optional[BatchPrefillWithPagedKVCacheWrapper] = None
decode_wrapper: Optional[BatchDecodeWithPagedKVCacheWrapper] = None
# cascade_wrapper: Optional[MultiLevelCascadeAttentionWrapper] = None
@property
def query_start_loc(self):
# The GPUModelRunner expects to be able to access this property.
return self.qo_indptr
def __post_init__(self):
# Refer to
# https://github.com/flashinfer-ai/flashinfer/blob/3d55c71a62052c590c130897d3a3db49b14fcc34/include/flashinfer/utils.cuh#L157
supported_head_sizes = FlashInferBackend.get_supported_head_sizes()
if self.head_dim is not None and self.head_dim \
not in supported_head_sizes:
raise ValueError(
f"Only {supported_head_sizes} are supported for head_dim,",
f" received {self.head_dim}.")
class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
def __init__(self, runner: GPUModelRunner, kv_cache_spec: AttentionSpec,
block_table: BlockTable):
self.runner = runner
self._workspace_buffer = None
self._prefill_wrapper = None # Wrapper for prefill/append
self._decode_wrapper = None # Wrapper for decode
self._cascade_wrapper = None # Wrapper for cascade attention
# Global hyperparameters shared by all attention layers
self.global_hyperparameters: Optional[PerLayerParameters] = None
self.vllm_config = runner.vllm_config
self.kv_cache_spec = kv_cache_spec
self.block_table = block_table
def reorder_batch(self, input_batch: InputBatch,
scheduler_output: SchedulerOutput) -> bool:
# We now want to reorder the batch so that the "decode" requests are and
# the front and the "prefill" requests are at the using the least amount
# swaps possible. (NOTE for now we loosely use "decode" to mean requests
# where attention is likely memory-bound and "prefill" to mean requests
# where attention is likely compute-bound, TODO(lucas): figure out a
# better naming here)
decodes = []
prefills = []
num_decode_tokens = 0
num_prefill_tokens = 0
for i, req_id in enumerate(input_batch.req_ids):
num_tokens = scheduler_output.num_scheduled_tokens[req_id]
# for now treat 1 scheduled token as "decode" even if its not,
# we should update this to something like < 8 in the future but
# currently the decode run only supports num_tokens = 1
if num_tokens == 1:
decodes.append(i)
num_decode_tokens += num_tokens
else:
prefills.append(i)
num_prefill_tokens += num_tokens
# We hope that this is fairly minimal since decodes
# should be around for a number of iterations so hopefully they are
# relatively stationary (and new request are generally appended to the
# persistent batch so already should be at the back)
# To achieve this we loop over the decodes in descending order and
# the prefills in ascending order. We swap decodes from the "back"
# i.e. past where the last decode should be in the reodorered with
# prefills from the front of the batch.
# `decodes` and `prefills` are already in ascending order just based on
# the above loop
num_decodes = len(decodes)
num_prefills = len(prefills)
modified_batch = False
for i in range(1, min(num_decodes, num_prefills) + 1):
# If the decode is at the "back" of the batch, i, we can swap it
# with the prefill closest to the front of the batch
decode_idx = decodes[num_decodes - i]
if decode_idx < num_decodes:
break
input_batch.swap_states(prefills[i - 1], decode_idx)
modified_batch = True
# Save for next `build` call
# TODO(lucas): this is a bit of a hack, we should probably have a
# better way of doing this
self._num_decodes = num_decodes
self._num_prefills = num_prefills
self._num_decode_tokens = num_decode_tokens
self._num_prefill_tokens = num_prefill_tokens
return modified_batch
def _get_workspace_buffer(self):
if self._workspace_buffer is None:
self._workspace_buffer = torch.empty(
FLASHINFER_WORKSPACE_BUFFER_SIZE,
dtype=torch.uint8,
device=self.runner.device)
return self._workspace_buffer
def _get_prefill_wrapper(self):
if self._prefill_wrapper is None:
self._prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper(
self._get_workspace_buffer(), "NHD")
return self._prefill_wrapper
def _get_decode_wrapper(self):
if self._decode_wrapper is None:
num_qo_heads = (self.runner.model_config.get_num_attention_heads(
self.runner.parallel_config))
num_kv_heads = self.runner.model_config.get_num_kv_heads(
self.runner.parallel_config)
use_tensor_cores = envs.VLLM_FLASHINFER_FORCE_TENSOR_CORES or (
num_qo_heads // num_kv_heads > 4)
self._decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
self._get_workspace_buffer(),
"NHD",
use_tensor_cores=use_tensor_cores)
return self._decode_wrapper
# def _get_cascade_wrapper(self):
# if self._cascade_wrapper is None:
# self._cascade_wrapper = MultiLevelCascadeAttentionWrapper(
# 2, self._get_workspace_buffer(), "NHD")
# return self._cascade_wrapper
def _plan(self, attn_metadata: FlashInferMetadata):
if self.global_hyperparameters is None:
self.global_hyperparameters = infer_global_hyperparameters(
get_per_layer_parameters(self.vllm_config))
if attn_metadata.use_cascade and False: # not supported
attn_metadata.cascade_wrapper = self._get_cascade_wrapper()
attn_metadata.cascade_wrapper.plan(
[attn_metadata.shared_qo_indptr, attn_metadata.qo_indptr],
[
attn_metadata.shared_kv_page_indptr,
attn_metadata.paged_kv_indptr
],
[
attn_metadata.shared_kv_page_indices,
attn_metadata.paged_kv_indices
],
[
attn_metadata.shared_kv_last_page_len,
attn_metadata.paged_kv_last_page_len
],
attn_metadata.num_qo_heads,
attn_metadata.num_kv_heads,
attn_metadata.head_dim,
attn_metadata.page_size,
causal=True,
sm_scale=self.global_hyperparameters.sm_scale,
window_left=self.global_hyperparameters.window_left,
logits_soft_cap=self.global_hyperparameters.logits_soft_cap,
q_data_type=attn_metadata.q_data_type,
)
else:
# Regular attention (common case).
# Decodes are at the front and prefills are at the back,
# according to reorder_batch()
if self._num_prefills > 0:
# Decodes are first so prefills start after the last decode
prefill_start = self._num_decodes
attn_metadata.prefill_wrapper = self._get_prefill_wrapper()
assert attn_metadata.qo_indptr[prefill_start:].shape[
0] == self._num_prefills + 1
assert attn_metadata.paged_kv_indptr[prefill_start:].shape[
0] == self._num_prefills + 1
assert attn_metadata.paged_kv_last_page_len[
prefill_start:].shape[0] == self._num_prefills
# Since prefill_wrapper.run() will be called with
# query[num_decode_tokens:] we need to adjust the qo_indptr
# to be relative to the start of the prefill queries.
qo_indptr = attn_metadata.qo_indptr[
prefill_start:] - attn_metadata.qo_indptr[prefill_start]
attn_metadata.prefill_wrapper.plan(
qo_indptr,
attn_metadata.paged_kv_indptr[prefill_start:],
attn_metadata.paged_kv_indices,
attn_metadata.paged_kv_last_page_len[prefill_start:],
attn_metadata.num_qo_heads,
attn_metadata.num_kv_heads,
attn_metadata.head_dim,
attn_metadata.page_size,
causal=True,
sm_scale=self.global_hyperparameters.sm_scale,
window_left=self.global_hyperparameters.window_left,
logits_soft_cap=self.global_hyperparameters.
logits_soft_cap,
q_data_type=attn_metadata.q_data_type,
kv_data_type=attn_metadata.data_type,
)
if self._num_decodes > 0:
attn_metadata.decode_wrapper = self._get_decode_wrapper()
attn_metadata.decode_wrapper.plan(
attn_metadata.paged_kv_indptr[:self._num_decodes + 1],
attn_metadata.paged_kv_indices,
attn_metadata.paged_kv_last_page_len[:self._num_decodes],
attn_metadata.num_qo_heads,
attn_metadata.num_kv_heads,
attn_metadata.head_dim,
attn_metadata.page_size,
# Disable flashinfer's pos encoding and use vllm's rope.
pos_encoding_mode="NONE",
sm_scale=self.global_hyperparameters.sm_scale,
window_left=self.global_hyperparameters.window_left,
logits_soft_cap=self.global_hyperparameters.
logits_soft_cap,
q_data_type=attn_metadata.q_data_type,
kv_data_type=attn_metadata.data_type,
)
def build(self, common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata):
num_reqs = common_attn_metadata.num_reqs
num_actual_tokens = common_attn_metadata.num_actual_tokens
assert self._num_decodes + self._num_prefills == num_reqs
assert (self._num_decode_tokens +
self._num_prefill_tokens == num_actual_tokens)
page_size = self.kv_cache_spec.block_size
device = self.runner.device
qo_indptr = common_attn_metadata.query_start_loc
seq_lens = common_attn_metadata.seq_lens
block_table_tensor = self.block_table.get_device_tensor()[:num_reqs]
slot_mapping = self.block_table.slot_mapping_cpu[:num_actual_tokens].to(
self.runner.device, non_blocking=True).long()
block_table_bounds = (seq_lens + page_size - 1) // page_size
use_cascade = common_prefix_len > 0
if use_cascade:
# Grab the blocks of the shared prefix from the first request.
assert common_prefix_len % page_size == 0
num_common_kv_blocks = common_prefix_len // page_size
shared_qo_indptr = torch.tensor([0, num_actual_tokens],
dtype=torch.int32,
device=device)
shared_kv_page_indptr = torch.tensor([0, num_common_kv_blocks],
dtype=torch.int32,
device=device)
shared_kv_page_indices = block_table_tensor[
0, :num_common_kv_blocks]
shared_kv_last_page_len = torch.tensor([page_size],
dtype=torch.int32,
device=device)
# Remove the blocks of the shared prefix from all requests.
block_table_tensor = block_table_tensor[:, num_common_kv_blocks:]
block_table_bounds -= num_common_kv_blocks
else:
shared_qo_indptr = None
shared_kv_page_indptr = None
shared_kv_page_indices = None
shared_kv_last_page_len = None
mask = (torch.arange(block_table_tensor.size(1),
dtype=block_table_tensor.dtype,
device=block_table_tensor.device).unsqueeze(0)
< block_table_bounds.unsqueeze(1))
paged_kv_indices = block_table_tensor[mask]
paged_kv_indptr = torch.cat([
torch.zeros(1,
dtype=block_table_bounds.dtype,
device=block_table_bounds.device),
block_table_bounds.cumsum(dim=0, dtype=torch.int32)
])
paged_kv_last_page_len = seq_lens % page_size
paged_kv_last_page_len = torch.where(paged_kv_last_page_len == 0,
page_size, paged_kv_last_page_len)
attn_metadata = FlashInferMetadata(
num_actual_tokens=num_actual_tokens,
qo_indptr=qo_indptr,
paged_kv_indptr=paged_kv_indptr,
paged_kv_indices=paged_kv_indices,
paged_kv_last_page_len=paged_kv_last_page_len,
num_qo_heads=self.runner.num_query_heads,
num_kv_heads=self.kv_cache_spec.num_kv_heads,
head_dim=self.kv_cache_spec.head_size,
page_size=page_size,
data_type=self.kv_cache_spec.dtype,
q_data_type=self.runner.dtype,
slot_mapping=slot_mapping,
num_decodes=self._num_decodes,
num_decode_tokens=self._num_decode_tokens,
num_prefills=self._num_prefills,
num_prefill_tokens=self._num_prefill_tokens,
use_cascade=use_cascade,
shared_qo_indptr=shared_qo_indptr,
shared_kv_page_indptr=shared_kv_page_indptr,
shared_kv_page_indices=shared_kv_page_indices,
shared_kv_last_page_len=shared_kv_last_page_len,
)
self._plan(attn_metadata)
return attn_metadata
def use_cascade_attention(self, *args, **kwargs) -> bool:
logger.warning_once(
"Using cascade attention in FlashInfer is not supported yet")
return False
if self.kv_cache_spec.dtype != self.runner.model_config.dtype:
# TODO: The cascade wrapper currently does not support setting
# kv cache dtype to something different from query dtype.
return False
return use_cascade_attention(*args, **kwargs)
class FlashInferImpl(AttentionImpl):
def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: int,
alibi_slopes: Optional[list[float]],
sliding_window: Optional[int],
kv_cache_dtype: str,
blocksparse_params: Optional[dict[str, Any]] = None,
logits_soft_cap: Optional[float] = None,
attn_type: AttentionType = AttentionType.DECODER,
kv_sharing_target_layer_name: Optional[int] = None,
use_irope: bool = False,
) -> None:
if use_irope:
logger.warning_once(
"Using irope in FlashInfer is not supported yet, it will fall"
" back to global attention for long context.")
self.num_heads = num_heads
self.head_size = head_size
self.scale = float(scale)
self.num_kv_heads = num_kv_heads
if alibi_slopes is not None:
alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
self.alibi_slopes = alibi_slopes
if sliding_window is None:
self.sliding_window = (-1, -1)
else:
self.sliding_window = (sliding_window - 1, 0)
self.kv_cache_dtype = kv_cache_dtype
self.logits_soft_cap = logits_soft_cap
self.kv_sharing_target_layer_name = kv_sharing_target_layer_name
assert self.num_heads % self.num_kv_heads == 0
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
if attn_type != AttentionType.DECODER:
raise NotImplementedError("Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"FlashInferImpl")
def forward(
self,
layer: torch.nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: FlashInferMetadata,
output: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass with FlashInfer.
Args:
query: shape = [num_tokens, num_heads, head_size]
key: shape = [num_tokens, num_kv_heads, head_size]
value: shape = [num_tokens, num_kv_heads, head_size]
kv_cache = [num_blocks, 2, block_size, num_kv_heads, head_size]
attn_metadata: Metadata for attention.
Returns:
shape = [num_tokens, num_heads * head_size]
"""
assert output is not None, "Output tensor must be provided."
if attn_metadata is None:
# Profiling run.
return output
# IMPORTANT!
# NOTE(woosuk): With piece-wise CUDA graphs, this method is executed in
# eager-mode PyTorch. Thus, we need to be careful about any CPU overhead
# in this method. For example, `view` and `slice` (or `[:n]`) operations
# are surprisingly slow even in the case they do not invoke any GPU ops.
# Minimize the PyTorch ops in this method as much as possible.
# Whenever making a change in this method, please benchmark the
# performance to make sure it does not introduce any overhead.
num_actual_tokens = attn_metadata.num_actual_tokens
if self.kv_sharing_target_layer_name is None:
# Reshape the input keys and values and store them in the cache.
# Skip this if sharing KV cache with an earlier attention layer.
# NOTE(woosuk): Here, key and value are padded while slot_mapping is
# not padded. However, we don't need to do key[:num_actual_tokens]
# and value[:num_actual_tokens] because the reshape_and_cache_flash
# op uses the slot_mapping's shape to determine the number of
# actual tokens.
torch.ops._C_cache_ops.reshape_and_cache_flash(
key,
value,
kv_cache[:, 0],
kv_cache[:, 1],
attn_metadata.slot_mapping,
self.kv_cache_dtype,
layer._k_scale,
layer._v_scale,
)
window_left = (self.sliding_window[0]
if self.sliding_window is not None else -1)
# Inputs and outputs may be padded for CUDA graphs
query = query[:num_actual_tokens]
output_padded = output
output = output[:num_actual_tokens]
# if attn_metadata.use_cascade:
# # Cascade attention (rare case).
# assert attn_metadata.cascade_wrapper is not None
# output.copy_(attn_metadata.cascade_wrapper.run(query, kv_cache))
# return output
num_decode_tokens = attn_metadata.num_decode_tokens
num_prefill_tokens = attn_metadata.num_prefill_tokens
# Regular attention (common case).
# Decodes are at the front and prefills are at the back,
# according to reorder_batch()
if prefill_wrapper := attn_metadata.prefill_wrapper:
prefill_query = query[num_decode_tokens:]
assert prefill_query.shape[0] == num_prefill_tokens
assert prefill_wrapper is not None
assert prefill_wrapper._causal
assert prefill_wrapper._window_left == window_left
assert prefill_wrapper._logits_soft_cap == (self.logits_soft_cap
or 0.0)
assert prefill_wrapper._sm_scale == self.scale
prefill_wrapper.run(
prefill_query,
kv_cache,
k_scale=layer._k_scale_float,
v_scale=layer._v_scale_float,
out=output[num_decode_tokens:],
)
if decode_wrapper := attn_metadata.decode_wrapper:
decode_query = query[:num_decode_tokens]
assert decode_query.shape[0] == num_decode_tokens
assert decode_wrapper is not None
assert decode_wrapper._window_left == window_left
assert decode_wrapper._logits_soft_cap == (self.logits_soft_cap
or 0.0)
assert decode_wrapper._sm_scale == self.scale
decode_wrapper.run(
decode_query,
kv_cache,
k_scale=layer._k_scale_float,
v_scale=layer._v_scale_float,
out=output[:num_decode_tokens],
)
return output_padded

View File

@@ -0,0 +1,473 @@
# SPDX-License-Identifier: Apache-2.0
"""Attention layer with FlashAttention."""
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Optional
import torch
from torch.nn.attention.flex_attention import (BlockMask, _mask_mod_signature,
_score_mod_signature,
create_block_mask,
flex_attention)
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata, AttentionType,
is_quantized_kv_cache)
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
CommonAttentionMetadata)
from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.v1.worker.block_table import BlockTable
if current_platform.is_cuda():
pass
logger = init_logger(__name__)
if TYPE_CHECKING:
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
create_block_mask_compiled = torch.compile(create_block_mask,
fullgraph=True,
mode="reduce-overhead")
flex_attention_compiled = torch.compile(flex_attention, fullgraph=True)
def _offsets_to_doc_ids_tensor(offsets: torch.Tensor) -> torch.Tensor:
device = offsets.device
counts = offsets[1:] - offsets[:-1]
return torch.repeat_interleave(
torch.arange(len(counts), device=device, dtype=torch.int32), counts)
class FlexAttentionBackend(AttentionBackend):
accept_output_buffer: bool = True
@staticmethod
def get_supported_head_sizes() -> list[int]:
return [16, 32, 64, 96, 128, 160, 192, 224, 256]
@staticmethod
def get_name() -> str:
return "FLEX_ATTENTION"
@staticmethod
def get_impl_cls() -> type["FlexAttentionImpl"]:
return FlexAttentionImpl
@staticmethod
def get_metadata_cls() -> type["AttentionMetadata"]:
return FlexAttentionMetadata
@staticmethod
def get_kv_cache_shape(
num_blocks: int,
block_size: int,
num_kv_heads: int,
head_size: int,
) -> tuple[int, ...]:
return (2, num_blocks, block_size, num_kv_heads, head_size)
@staticmethod
def get_builder_cls() -> type["FlexAttentionMetadataBuilder"]:
return FlexAttentionMetadataBuilder
@staticmethod
def use_cascade_attention(*args, **kwargs) -> bool:
return False
# @torch.compile(fullgraph=True, mode="reduce-overhead")
def physical_to_logical_mapping(
block_table: torch.Tensor,
total_blocks: Optional[int] = None) -> torch.Tensor:
"""
Creates an inverse mapping from physical block locations to logical indices.
The original block_table maps from logical blocks to physical locations:
Logical to Physical (Original block_table):
┌───────────────────────────────────────────┐
│ Request 0: │
│ │
│ Logical Blocks: 0 1 2 3 4 5 6 7 │
│ │ │ │ │ │ │ │ │ │
│ v v v v v v v v │
│ Physical Blocks: 3 5 1 7 4 2 0 6 │
└───────────────────────────────────────────┘
This function creates the inverse mapping:
Physical to Logical (Inverse mapping):
┌───────────────────────────────────────────┐
│ Request 0: │
│ │
│ Physical Blocks: 0 1 2 3 4 5 6 7 │
│ │ │ │ │ │ │ │ │ │
│ v v v v v v v v │
│ Logical Blocks: 6 2 5 0 4 1 7 3 │
└───────────────────────────────────────────┘
If multiple logical blocks map to the same physical block,
this function returns the first (minimum) logical block index.
If a physical block is not mapped to by any logical block,
its value in the result will be -1.
Args:
block_table: Tensor of shape [max_reqs, max_num_blocks]
mapping logical blocks to physical locations
Returns:
A tensor of shape [max_reqs, max_physical_block]
"""
max_reqs, max_num_blocks = block_table.shape
device = block_table.device
physical_to_logical = torch.full((max_reqs, total_blocks),
-1,
dtype=torch.long,
device=device)
logical_indices = (torch.arange(max_num_blocks,
device=device).unsqueeze(0).expand(
max_reqs, -1))
physical_to_logical.scatter_(-1, block_table.to(torch.int64),
logical_indices)
# TODO Confirm - Seems like block 0 is always empty so we reset it manually
physical_to_logical[:, 0] = -1
return physical_to_logical
def causal_mask_mod(b: torch.Tensor, h: torch.Tensor, q_idx: torch.Tensor,
kv_idx: torch.Tensor):
return q_idx >= kv_idx
@dataclass
class FlexAttentionMetadata:
num_actual_tokens: int # Number of tokens excluding padding.
max_query_len: int
query_start_loc: torch.Tensor
max_seq_len: int
seq_lens: torch.Tensor
block_table: torch.Tensor
slot_mapping: torch.Tensor
use_cascade: bool
common_prefix_len: int
cu_prefix_query_lens: Optional[torch.Tensor]
prefix_kv_lens: Optional[torch.Tensor]
suffix_kv_lens: Optional[torch.Tensor]
# Block info
total_cache_tokens: int
block_size: int
max_possible_sequence_length: int
num_reqs: int
physical_to_logical: torch.Tensor
decode_offset: torch.Tensor
# For logging.
num_input_tokens: int = 0 # Number of tokens including padding.
# Flex Metadata
num_blocks = 0
block_mask: Optional[BlockMask] = None
score_mod: Optional[_score_mod_signature] = None
mask_mod: Optional[_mask_mod_signature] = None
logical_mask_mod: _mask_mod_signature = causal_mask_mod
def get_mask_mod(self) -> _mask_mod_signature:
"""Creates the mask_mod function for FlexAttention.
This function creates the combined mask mod function that handles:
1. The paged attention block mapping
2. The mapping from packed query sequences to logical query entries
It also by defaults adds the decoding offset to the query indices.
With this info we create the "logical" indices that are passed to
mask_mod functions. This allows mask mod functions to be agnostic to
layout of the query and key/value tensors.
TODO is_within_lower_bound: do sequences start on block_boundaries?
"""
# Create a lookup mapping from query indices -> request number
request_lookup = _offsets_to_doc_ids_tensor(self.query_start_loc)
def final_mask_mod(
b: torch.Tensor,
h: torch.Tensor,
q_idx: torch.Tensor,
physical_kv_idx: torch.Tensor,
) -> torch.Tensor:
# Map query indices to corresponding request indices
q_req = request_lookup[q_idx]
# Convert physical KV indices to logical indices
physical_kv_block = physical_kv_idx // self.block_size
physical_kv_offset = physical_kv_idx % self.block_size
logical_block_idx = self.physical_to_logical[q_req,
physical_kv_block]
logical_kv_idx = logical_block_idx * self.block_size + physical_kv_offset # noqa: E501
# Determine valid kv indices
live_block = logical_block_idx >= 0
within_upper_bound = logical_kv_idx < self.seq_lens[q_req]
within_lower_bound = logical_kv_idx >= 0
is_valid = live_block & within_upper_bound & within_lower_bound
# Convert physical query indices to logical indices
local_q_idx = q_idx - self.query_start_loc[q_req]
logical_q_idx = local_q_idx + self.decode_offset[q_req]
# Apply mask modification only for valid indices
return torch.where(
is_valid,
self.logical_mask_mod(b, h, logical_q_idx, logical_kv_idx),
False,
)
return final_mask_mod
def build_block_mask(self) -> BlockMask:
assert self.mask_mod is not None
return create_block_mask_compiled(
self.mask_mod,
None,
None,
self.num_actual_tokens,
self.total_cache_tokens,
)
def __post_init__(self):
assert self.use_cascade is False, "Not implemented yet."
assert self.common_prefix_len == 0, "Not implemented yet."
assert self.cu_prefix_query_lens is None, "Not implemented yet."
assert self.prefix_kv_lens is None, "Not implemented yet."
assert self.suffix_kv_lens is None, "Not implemented yet."
self.num_blocks = self.total_cache_tokens // self.block_size
self.mask_mod = self.get_mask_mod()
self.block_mask = self.build_block_mask()
class FlexAttentionMetadataBuilder(
AttentionMetadataBuilder[FlexAttentionMetadata]):
def __init__(self, runner: "GPUModelRunner", kv_cache_spec: AttentionSpec,
block_table: BlockTable):
model_config = runner.model_config
self.runner = runner
self.num_heads_q = model_config.get_num_attention_heads(
runner.parallel_config)
self.num_heads_kv = model_config.get_num_kv_heads(
runner.parallel_config)
self.headdim = model_config.get_head_size()
self.block_size = kv_cache_spec.block_size
self.kv_cache_spec = kv_cache_spec
self.block_table = block_table
def build(self, common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata):
num_reqs = common_attn_metadata.num_reqs
num_actual_tokens = common_attn_metadata.num_actual_tokens
max_query_len = common_attn_metadata.max_query_len
max_seq_len = self.runner.seq_lens_np[:num_reqs].max()
query_start_loc = common_attn_metadata.query_start_loc
seq_lens = common_attn_metadata.seq_lens
block_table = self.block_table
block_table_tensor = block_table.get_device_tensor()[:num_reqs]
block_table.slot_mapping[:num_actual_tokens].copy_(
block_table.slot_mapping_cpu[:num_actual_tokens],
non_blocking=True)
slot_mapping = block_table.slot_mapping[:num_actual_tokens]
use_cascade = common_prefix_len > 0
cu_prefix_query_lens = None
prefix_kv_lens = None
suffix_kv_lens = None
if use_cascade:
raise NotImplementedError("Not yet my friend")
block_size = self.kv_cache_spec.block_size
max_possible_seq_len = self.runner.model_config.max_model_len
total_cache_tokens = (self.runner.cache_config.num_gpu_blocks *
block_size)
inverse_block_table = physical_to_logical_mapping(
block_table_tensor, self.runner.cache_config.num_gpu_blocks)
# Get the original offset tensor
offset_tensor = torch.tensor(
self.runner.input_batch.num_computed_tokens_cpu[:num_reqs]).to(
self.runner.device, non_blocking=True)
out = FlexAttentionMetadata(
num_actual_tokens=num_actual_tokens,
max_query_len=max_query_len,
query_start_loc=query_start_loc,
max_seq_len=max_seq_len,
seq_lens=seq_lens,
block_table=block_table_tensor,
slot_mapping=slot_mapping,
use_cascade=use_cascade,
common_prefix_len=common_prefix_len,
cu_prefix_query_lens=cu_prefix_query_lens,
prefix_kv_lens=prefix_kv_lens,
suffix_kv_lens=suffix_kv_lens,
block_size=block_size,
max_possible_sequence_length=max_possible_seq_len,
num_reqs=num_reqs,
physical_to_logical=inverse_block_table,
total_cache_tokens=total_cache_tokens,
decode_offset=offset_tensor,
)
return out
class FlexAttentionImpl(AttentionImpl):
sliding_window: Optional[tuple[int, int]]
alibi_slopes: Optional[torch.Tensor]
logits_soft_cap: Optional[float]
def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: int,
alibi_slopes: Optional[list[float]],
sliding_window: Optional[int],
kv_cache_dtype: str,
blocksparse_params: Optional[dict[str, Any]] = None,
logits_soft_cap: Optional[float] = None,
attn_type: AttentionType = AttentionType.DECODER,
kv_sharing_target_layer_name: Optional[str] = None,
) -> None:
if blocksparse_params is not None:
# TODO we should support this :think
raise ValueError(
"FlashAttention does not support block-sparse attention.")
self.num_heads = num_heads
self.head_size = head_size
self.scale = float(scale)
self.num_kv_heads = num_kv_heads
if alibi_slopes is not None:
raise NotImplementedError(
"FlexAttention does not support alibi slopes yet.")
else:
self.alibi_slopes = None
if sliding_window is not None:
raise NotImplementedError(
"FlexAttention does not support sliding window yet.")
else:
self.sliding_window = (-1, -1)
self.kv_cache_dtype = kv_cache_dtype
self.logits_soft_cap = logits_soft_cap
if self.logits_soft_cap is not None:
raise NotImplementedError(
"FlexAttention does not support logits soft cap yet.")
assert self.num_heads % self.num_kv_heads == 0
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
if kv_sharing_target_layer_name is not None:
raise NotImplementedError(
"FlexAttention does not support kv sharing yet.")
support_head_sizes = FlexAttentionBackend.get_supported_head_sizes()
if head_size not in support_head_sizes:
raise ValueError(
f"Head size {head_size} is not supported by FlashAttention. "
f"Supported head sizes are: {support_head_sizes}. "
"Set VLLM_USE_V1=0 to use another attention backend.")
if is_quantized_kv_cache(self.kv_cache_dtype):
raise NotImplementedError(
"FlexAttention does not support quantized kv-cache. Yet")
@staticmethod
def view_as_4d(tensor: torch.Tensor) -> torch.Tensor:
"""View a 3d tensor as 4D."""
if tensor.ndim == 4:
return tensor
assert tensor.ndim == 3
return tensor[None, :, :, :]
def forward(
self,
layer: torch.nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: FlexAttentionMetadata,
output: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass with FLexAttention.
Args:
query: shape = [num_tokens, num_heads, head_size]
key: shape = [num_tokens, num_kv_heads, head_size]
value: shape = [num_tokens, num_kv_heads, head_size]
kv_cache = [2, num_blocks, block_size, num_kv_heads, head_size]
attn_metadata: Metadata for attention.
Returns:
shape = [num_tokens, num_heads * head_size]
"""
assert output is not None, "Output tensor must be provided."
enable_gqa = self.num_kv_heads != self.num_heads
if attn_metadata is None:
# Profiling run.
return output
# query = self.view_as_4d(query).permute(0, 2, 1, 3)
# return torch.empty_like(query)
num_actual_tokens = attn_metadata.num_actual_tokens
key_cache, value_cache = kv_cache.unbind(0)
torch.ops._C_cache_ops.reshape_and_cache_flash(
key,
value,
key_cache,
value_cache,
attn_metadata.slot_mapping,
self.kv_cache_dtype,
layer._k_scale,
layer._v_scale,
)
# View out the block_size dim
key_cache = key_cache.view(-1, self.num_kv_heads, self.head_size)
value_cache = value_cache.view(-1, self.num_kv_heads, self.head_size)
query, key_cache, value_cache = map(
lambda x: self.view_as_4d(x).permute(0, 2, 1, 3),
(query, key_cache, value_cache),
)
query = query[:, :, :num_actual_tokens, :]
# Doesn't work for now -> constraint violation
# torch._dynamo.try_mark_dynamic(query, 2)
out = flex_attention_compiled(
query,
key_cache,
value_cache,
attn_metadata.score_mod,
attn_metadata.block_mask,
self.scale,
enable_gqa=enable_gqa,
kernel_options={"FORCE_USE_FLEX_ATTENTION": True},
)
# Flex doesn't have an out variant today, rely on epilogue fusion
out = out.permute(0, 2, 1, 3).squeeze(0)
output[:num_actual_tokens, :, :].copy_(out)
return output

View File

@@ -0,0 +1,975 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
# MLA Common Components
This file implements common components for MLA implementations.
First we define:
Sq as Q sequence length
Skv as KV sequence length
MLA has two possible ways of computing, a data-movement friendly approach and a
compute friendly approach, we generally want to use the compute friendly
approach for "prefill" (i.e. the ratio Sq / Skv is "small", is near 1)
and the data-movement friendly approach for "decode" (i.e. the ratio
Sq / Skv is "large").
NOTE what we deem small and large is currently determined by if its labelled
prefill or decode by the scheduler, but this is something we should probably
tune.
Main reference: DeepseekV2 paper, and FlashInfer Implementation
(https://arxiv.org/abs/2405.04434 and https://github.com/flashinfer-ai/flashinfer/pull/551).
Deepseek's MLA attention works the following way:
* Use a single latent vector to represent the per-token entry of the KV cache.
* For decode (i.e. the memory friendly approach) the attention "simulates" a
multi-head attention, while the compute is similar to multi-query attention.
Below is example of both paths assuming batchsize = 1
## More Extent Definitions:
C Context length, `Skv - Sq`
H hidden size
N number of attention heads
Lq latent dimension for Q 1536 in DSV3
Lkv latent dimension for K/V 512 in DSV3
P nope dimension, no rope. 128 in DSV3
R rope dimension, goes through rope. 64 in DSV3
V V head dim. 128 in DSV3
## Vector/Matrix Definitions
h_t hidden states (input to attention) shape [Sq, H]
q_c latent/compressed Q shape [Sq, Lq]
q_nope uncompressed Q (no-rope) shape [Sq, N, P]
q_pe uncompressed Q (rope) shape [Sq, N, R]
kv_c latent/compressed KV shape [Skv, Lkv]
k_pe decoupled k position embeddings shape [Skv, R]
new_kv_c new kv_c from current iter shape [Sq, Lkv]
new_k_pe new k_pe from current iter shape [Sq, R]
cache_kv_c cached k_c from previous iters shape [C, Lkv]
cache_k_pe cached k_pe from previous iters shape [C, R]
W_DQ project h_t to q_c shape [H, Lq]
W_UQ project q_c to q_nope shape [Lq, N * P]
W_QR project q_c to q_pe shape [Lq, N * R]
W_DKV project h_t to kv_c shape [H, Lkv]
W_UK project kv_c to k_nope shape [Lkv, N, P]
W_KR project h_t to k_pe shape [H, R]
W_UV project kv_c to v shape [Lkv, N, V]
W_O project v to h_t shape [N * V, H]
## Compute Friendly Approach (i.e. "_forward_prefill"):
q_c = h_t @ W_DQ
q_nope = (q_c @ W_UQ).view(Sq, N, P)
q_pe = RoPE(q_c @ W_QR).view(Sq, N, R)
new_kv_c = h_t @ W_DKV
new_k_pe = RoPE(h_t @ W_KR)
kv_c = torch.cat([new_kv_c, cache_kv_c], dim=0)
k_pe = torch.cat([new_k_pe, cache_k_pe], dim=0)
k_nope = (kv_c @ W_UK.view(Lkv, N * P)).view(Skv, N, P)
v = (kv_c @ W_UV.view(Lkv, N * V)).view(Skv, N, V)
// MHA with QK headdim = P + R
// V headdim = V
// spda_o shape [Sq, N, V]
spda_o = scaled_dot_product_attention(
torch.cat([q_nope, q_pe], dim=-1),
torch.cat([k_nope, k_pe.unsqueeze(1).expand(-1, N, -1)], dim=-1),
v
)
return spda_o @ W_O
NOTE: in the actual code,
`kv_b_proj` is [W_UK; W_UV] concatenated per head
`q_b_proj` is [W_UQ; W_QR] concatenated per head
`out_proj` is W_O
## Data-Movement Friendly Approach (i.e. "_forward_decode"):
Runtime
q_c = h_t @ W_DQ
q_nope = (q_c @ W_UQ).view(-1, N, P)
ql_nope = einsum("snh,lnh->snl", q, W_UK)
q_pe = RoPE(q_c @ W_QR).view(Sq, N, R)
new_kv_c = h_t @ W_DKV
new_k_pe = RoPE(h_t @ W_KR)
kv_c = torch.cat([new_kv_c, cache_kv_c], dim=0)
k_pe = torch.cat([new_k_pe, cache_k_pe], dim=0)
// MQA with QK headdim = Lkv + R
// V headdim = Lkv
// spda_o shape [Sq, N, Lkv]
// NOTE: this is less compute-friendly since Lkv > P
// but is more data-movement friendly since its MQA vs MHA
spda_o = scaled_dot_product_attention(
torch.cat([ql_nope, q_pe], dim=-1),
torch.cat([kv_c, k_pe], dim=-1),
kv_c
)
o = einsum("snl,lnv->snv", spda_o.reshape(-1, N, Lkv), W_UV)
return o.view(-1, N * V) @ self.num_heads @ W_O
## Chunked Prefill
For chunked prefill we want to use the compute friendly algorithm. We are
assuming sufficiently large Sq / Skv ratio, in the future may want to switch to
the data-movement friendly approach if the chunk (i.e. `Sq`) is small.
However, the compute-friendly approach can potentially run out of memory if Skv
is large due to: `k_nope = (kv_c @ W_UK).view(Skv, N, P)`
To mitigate this, we chunk the computation of attention with respect to the
current context (i.e. `cache_kv_c` and `cache_k_pe`) so that we can used a
fixed workspace size.
The chunked prefill approach is as follows:
MCC Max chunk of context to process per iter, computed dynamically,
used to bound the memory usage
q_c = h_t @ W_DQ
q_nope = (q_c @ W_UQ).view(Sq, N, P)
q_pe = RoPE(q_c @ W_QR).view(Sq, N, R)
new_kv_c = h_t @ W_DKV
new_k_pe = RoPE(h_t @ W_KR)
new_k_nope = (new_kv_c @ W_UK.view(Lkv, N * P)).view(Sq, N, P)
new_v = (new_kv_c @ W_UV.view(Lkv, N * V)).view(Sq, N, V)
// MHA between queries and new KV
// with QK headdim = P + R
// V headdim = V
// curr_o shape [Sq, N, V]
// curr_lse shape [N, Sq], this is just order FA returns
curr_o, curr_lse = scaled_dot_product_attention(
torch.cat([q_nope, q_pe], dim=-1),
torch.cat([new_k_nope, new_k_pe.unsqueeze(1).expand(-1, N, -1)], dim=-1),
new_v,
casual=True,
return_softmax_lse=True
)
// Compute attention with the already existing context
for chunk_idx in range(cdiv(C, MCC)):
chunk_start = chunk_idx * MCC
chunk_end = min(chunk_start + MCC, C)
Sc = chunk_end - chunk_start
cache_kv_c_chunk = cache_kv_c[chunk_start:chunk_end]
cache_k_pe_chunk = cache_k_pe[chunk_start:chunk_end]
cache_k_nope_chunk = (cache_kv_c_chunk @ W_UK).view(-1, N, P)
cache_v_chunk = (cache_kv_c_chunk @ W_UV).view(-1, N, V)
chunk_o, chunk_lse = scaled_dot_product_attention(
torch.cat([q_nope, q_pe], dim=-1),
torch.cat([cache_k_nope_chunk,
cache_k_pe_chunk.unsqueeze(1).expand(-1, N, -1)],
dim=-1),
cache_v_chunk,
casual=False,
return_softmax_lse=True
)
curr_o, curr_lse = merge_attn_states(
suffix_output=curr_o,
suffix_lse=curr_lse,
prefix_output=chunk_o,
prefix_lse=chunk_lse,
)
return curr_o @ W_O
"""
import functools
from abc import abstractmethod
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar
import torch
from vllm import _custom_ops as ops
from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer,
AttentionMetadata,
MLAAttentionImpl)
from vllm.attention.backends.utils import get_mla_dims
from vllm.attention.ops.merge_attn_states import merge_attn_states
# from vllm.attention.utils.fa_utils import get_flash_attn_version
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
LinearBase,
UnquantizedLinearMethod)
from vllm.platforms import current_platform
from vllm.utils import cdiv, round_down
from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
CommonAttentionMetadata)
from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.v1.worker.block_table import BlockTable
try:
from vllm.vllm_flash_attn import flash_attn_varlen_func
is_vllm_fa = True
except ImportError:
# For rocm use upstream flash attention
from flash_attn import flash_attn_varlen_func
is_vllm_fa = False
if TYPE_CHECKING:
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.worker.gpu_input_batch import InputBatch
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
from vllm import envs
logger = init_logger(__name__)
def get_flash_attn_version():
return None
class MLACommonBackend(AttentionBackend):
accept_output_buffer: bool = True
@staticmethod
def get_name() -> str:
return "TRITON_MLA_VLLM_V1"
@staticmethod
def get_metadata_cls() -> type["AttentionMetadata"]:
return MLACommonMetadata
@staticmethod
def get_builder_cls() -> type["MLACommonMetadataBuilder"]:
return MLACommonMetadataBuilder
@staticmethod
def get_kv_cache_shape(
num_blocks: int,
block_size: int,
num_kv_heads: int, # assumed to be 1 for MLA
head_size: int,
) -> tuple[int, ...]:
return (num_blocks, block_size, head_size)
@staticmethod
def get_supported_head_sizes() -> list[int]:
return [576]
@dataclass
class MLACommonPrefillMetadata:
""" Prefill Specific Metadata """
@dataclass
class ChunkedContextMetadata:
# New for MLA (compared to FlashAttention)
# For handling chunked prefill
cu_seq_lens: torch.Tensor
starts: torch.Tensor
seq_tot: list[int]
max_seq_lens: list[int]
workspace: torch.Tensor
block_table: torch.Tensor
query_start_loc: torch.Tensor
max_query_len: int
chunked_context: Optional[ChunkedContextMetadata] = None
@dataclass
class MLACommonDecodeMetadata:
block_table: torch.Tensor
seq_lens: torch.Tensor
D = TypeVar("D", bound=MLACommonDecodeMetadata)
@dataclass
class MLACommonMetadata(Generic[D]):
"""Metadata for MLACommon.
NOTE: Please read the comment at the top of the file before trying to
understand this class
"""
# NOTE(sang): Definition of context_len, query_len, and seq_len.
# |---------- N-1 iteration --------|
# |---------------- N iteration ---------------------|
# |- tokenA -|......................|-- newTokens ---|
# |---------- context_len ----------|
# |-------------------- seq_len ---------------------|
# |-- query_len ---|
num_actual_tokens: int # Number of tokens excluding padding.
query_start_loc: torch.Tensor
slot_mapping: torch.Tensor
# New for MLA (compared to FlashAttention)
# For handling prefill decode split
num_decodes: int
num_decode_tokens: int
num_prefills: int
# The dimension of the attention heads
head_dim: Optional[int] = None
decode: Optional[D] = None
prefill: Optional[MLACommonPrefillMetadata] = None
def __post_init__(self):
supported_head_sizes = MLACommonBackend.get_supported_head_sizes()
if self.head_dim is not None and self.head_dim \
not in supported_head_sizes:
raise ValueError(
f"Only {supported_head_sizes} are supported for head_dim,",
f"received {self.head_dim}.")
M = TypeVar("M", bound=MLACommonMetadata)
class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
"""
NOTE: Please read the comment at the top of the file before trying to
understand this class
"""
def __init__(self,
runner: "GPUModelRunner",
kv_cache_spec: AttentionSpec,
block_table: BlockTable,
metadata_cls: Optional[type[M]] = None):
self.metadata_cls = metadata_cls \
if metadata_cls is not None else MLACommonMetadata
self.runner = runner
scheduler_config = runner.scheduler_config
model_config = runner.model_config
cache_config = runner.cache_config
self.chunked_prefill_enabled = scheduler_config.chunked_prefill_enabled
self.num_heads = model_config.get_num_attention_heads(
runner.parallel_config)
self.mla_dims = get_mla_dims(model_config)
self.aot_schedule = False # current_platform.is_cuda()
self.kv_cache_spec = kv_cache_spec
# Dont try to access the runner on AMD
if self.aot_schedule:
self.page_size = self.kv_cache_spec.block_size
if self.chunked_prefill_enabled:
self.chunked_prefill_workspace_size = min(
# Max sure there is enough for 8 full length request or at least
# 4 pages of cache per request
max(
8 * model_config.max_model_len, 4 *
scheduler_config.max_num_seqs * cache_config.block_size),
# For long-context models try not to over-allocate limiting
# kv-cache space, limiting it to 64k tokens,
# which would result in the workspace being:
# 2*(576)*(64*1024) = 144mb
# (assuming 576 MLA head dim, and fp16)
# which would result in up-projected context being
# 2*(192*128)*(64*1024) = 3gb
# (assuming 192 QK head dim, 128 heads, and fp16)
128 * 1024)
assert self.chunked_prefill_workspace_size >= \
scheduler_config.max_num_seqs * cache_config.block_size
self.chunked_prefill_workspace = torch.empty(
(self.chunked_prefill_workspace_size,
model_config.get_head_size()),
dtype=model_config.dtype,
device=runner.device,
)
self.block_table = block_table
def reorder_batch(self, input_batch: "InputBatch",
scheduler_output: "SchedulerOutput") -> bool:
# We now want to reorder the batch so that the "decode" requests are and
# the front and the "prefill" requests are at the using the least amount
# swaps possible. (NOTE for now we loosely use "decode" to mean requests
# where attention is likely memory-bound and "prefill" to mean requests
# where attention is likely compute-bound, TODO(lucas): figure out a
# better naming here)
decodes = []
prefills = []
num_decode_tokens = 0
num_prefill_tokens = 0
for i, req_id in enumerate(input_batch.req_ids):
num_tokens = scheduler_output.num_scheduled_tokens[req_id]
# for now treat 1 scheduled token as "decode" even if its not,
# we should update this to something like < 8 in the future but
# currently the TritonMLA._forward_decode only supports
# num_tokens = 1
if num_tokens == 1:
decodes.append(i)
num_decode_tokens += num_tokens
else:
prefills.append(i)
num_prefill_tokens += num_tokens
# We hope that this is fairly minimal since decodes
# should be around for a number of iterations so hopefully they are
# relatively stationary (and new request are generally appended to the
# persistent batch so already should be at the back)
# To achieve this we loop over the decodes in descending order and
# the prefills in ascending order. We swap decodes from the "back"
# i.e. past where the last decode should be in the reodorered with
# prefills from the front of the batch.
# `decodes` and `prefills` are already in ascending order just based on
# the above loop
num_decodes = len(decodes)
num_prefills = len(prefills)
modified_batch = False
for i in range(1, min(num_decodes, num_prefills) + 1):
# If the decode is at the "back" of the batch, i, we can swap it
# with the prefill closest to the front of the batch
decode_idx = decodes[num_decodes - i]
if decode_idx < num_decodes:
break
input_batch.swap_states(prefills[i - 1], decode_idx)
modified_batch = True
# Save for next `build` call
# TODO(lucas): this is a bit of a hack, we should probably have a
# better way of doing this
self._num_decodes = num_decodes
self._num_prefills = num_prefills
self._num_decode_tokens = num_decode_tokens
self._num_prefill_tokens = num_prefill_tokens
return modified_batch
def _build_decode(self, block_table_tensor: torch.Tensor,
seq_lens: torch.Tensor):
return MLACommonDecodeMetadata(
block_table=block_table_tensor,
seq_lens=seq_lens,
)
def build_for_cudagraph_capture(
self, common_attn_metadata: CommonAttentionMetadata) -> M:
"""
This method builds the metadata for full cudagraph capture.
Currently, only decode is supported for full cudagraphs with MLA.
"""
m = common_attn_metadata
assert m.num_reqs == m.num_actual_tokens, \
"MLA only supports decode-only full CUDAGraph capture. " \
"Make sure all cudagraph capture sizes <= max_num_seq."
m.max_query_len = 1 # decode-only
# Update state usually set in reorder_batch.
self._num_decodes = m.num_reqs
self._num_decode_tokens = m.num_actual_tokens
self._num_prefills = 0
self._num_prefill_tokens = 0
return self.build(0, m)
def build(self, common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata) -> M:
num_reqs = common_attn_metadata.num_reqs
num_actual_tokens = common_attn_metadata.num_actual_tokens
max_query_len = common_attn_metadata.max_query_len
assert self._num_decodes + self._num_prefills == num_reqs
# Note(simon): be careful about the CPU <> GPU memory movement in this
# function. We should avoid GPU -> CPU sync as much as possible because
# it blocks on all previous kernels.
device = self.runner.device
block_table = self.block_table
block_table_tensor = block_table.get_device_tensor()[:num_reqs]
block_table.slot_mapping[:num_actual_tokens].copy_(
block_table.slot_mapping_cpu[:num_actual_tokens],
non_blocking=True)
block_table.slot_mapping[num_actual_tokens:].fill_(-1)
slot_mapping = block_table.slot_mapping[:num_actual_tokens]
query_start_loc = common_attn_metadata.query_start_loc
seq_lens = common_attn_metadata.seq_lens
prefill_metadata = None
if self._num_prefills > 0:
reqs_start = self._num_decodes # prefill_start
context_lens_cpu = self.runner.input_batch.\
num_computed_tokens_cpu_tensor[reqs_start:num_reqs]
max_context_len_cpu = context_lens_cpu.max().item()
num_prefills_with_context_cpu = (context_lens_cpu > 0).sum().item()
prefill_query_start_loc = query_start_loc[
reqs_start:] - query_start_loc[reqs_start]
chunked_context_metadata = None
if self.chunked_prefill_enabled and self._num_prefills > 0 \
and max_context_len_cpu > 0:
# NOTE: it is recommend you read the `Chunked Prefill` section
# in the comment at the top of the file before trying to
# understand the following code
# currently we allocate an equal amount of workspace for each
# prefill in the batch, we could probably use a more advanced
# algorithm here and allocate more workspace to prefills with
# longer context lengths
max_context_chunk = (self.chunked_prefill_workspace_size //
num_prefills_with_context_cpu)
if self.aot_schedule:
# align max_context_chunk to page_size by rounding down,
# currently the `gather_cache` kernel cannot handle
# `context_chunk_starts` that are not aligned to page_size
max_context_chunk = round_down(max_context_chunk,
self.page_size)
assert max_context_chunk > 0
num_chunks = cdiv(max_context_len_cpu, max_context_chunk)
# if `max_context_chunk = 256`, `num_chunks = 3`, and
# `num_prefills_with_context = 4`, create a tensor that looks
# like
# [[0, 0, 0, 0], [256, 256, 256, 256], [512, 512, 512, 512]]
# Note(simon): this is done in CPU because of downstream's
# of `to_list`.
chunk_starts = \
torch.arange(num_chunks, dtype=torch.int32) \
.unsqueeze(1).expand(-1, self._num_prefills) \
* max_context_chunk
chunk_ends = torch.min(context_lens_cpu.unsqueeze(0),
chunk_starts + max_context_chunk)
chunk_seq_lens = (chunk_ends - chunk_starts).clamp(min=0)
cu_seq_lens_cpu = torch.zeros(num_chunks,
self._num_prefills + 1,
dtype=torch.int32,
pin_memory=True)
torch.cumsum(chunk_seq_lens,
dim=1,
out=cu_seq_lens_cpu[:, 1:],
dtype=torch.int32)
chunked_context_metadata = \
MLACommonPrefillMetadata.ChunkedContextMetadata(
cu_seq_lens=cu_seq_lens_cpu.to(device, non_blocking=True),
starts=chunk_starts.to(device, non_blocking=True),
seq_tot=chunk_seq_lens.sum(dim=1).tolist(),
max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
workspace=self.chunked_prefill_workspace,
)
assert max(chunked_context_metadata.max_seq_lens) <= \
self.chunked_prefill_workspace_size
prefill_metadata = MLACommonPrefillMetadata(
block_table=block_table_tensor[reqs_start:, ...],
query_start_loc=prefill_query_start_loc,
max_query_len=max_query_len,
chunked_context=chunked_context_metadata,
)
decode_metadata = None
if self._num_decodes > 0:
decode_metadata = self._build_decode(
block_table_tensor=block_table_tensor[:self._num_decodes, ...],
seq_lens=seq_lens[:self._num_decodes],
)
return self.metadata_cls(
num_actual_tokens=num_actual_tokens,
query_start_loc=query_start_loc,
slot_mapping=slot_mapping,
head_dim=self.runner.model_config.get_head_size(),
# MLACommonMetadata Chunk prefill specific
num_decodes=self._num_decodes,
num_decode_tokens=self._num_decode_tokens,
num_prefills=self._num_prefills,
prefill=prefill_metadata,
decode=decode_metadata,
)
def can_run_in_cudagraph(
self, common_attn_metadata: CommonAttentionMetadata) -> bool:
return common_attn_metadata.max_query_len == 1
class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
"""
NOTE: Please read the comment at the top of the file before trying to
understand this class
"""
def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: int,
alibi_slopes: Optional[list[float]],
sliding_window: Optional[int],
kv_cache_dtype: str,
blocksparse_params: Optional[dict[str, Any]],
logits_soft_cap: Optional[float],
attn_type: str,
kv_sharing_target_layer_name: Optional[str],
# MLA Specific Arguments
q_lora_rank: Optional[int],
kv_lora_rank: int,
qk_nope_head_dim: int,
qk_rope_head_dim: int,
qk_head_dim: int,
v_head_dim: int,
kv_b_proj: ColumnParallelLinear,
) -> None:
if kv_sharing_target_layer_name is not None:
raise NotImplementedError("KV sharing is not supported for MLA")
self.num_heads = num_heads
self.head_size = head_size
self.scale = float(scale)
self.num_kv_heads = num_kv_heads
self.kv_cache_dtype = kv_cache_dtype
self.q_lora_rank = q_lora_rank
self.kv_lora_rank = kv_lora_rank
self.qk_nope_head_dim = qk_nope_head_dim
self.qk_rope_head_dim = qk_rope_head_dim
self.qk_head_dim = qk_head_dim
self.v_head_dim = v_head_dim
self.kv_b_proj = kv_b_proj
self.vllm_flash_attn_version = get_flash_attn_version()
# Handle the differences between the flash_attn_varlen from flash_attn
# and the one from vllm_flash_attn. The former is used on RoCM and the
# latter has an additional parameter to control FA2 vs FA3
self.flash_attn_varlen_func = flash_attn_varlen_func
self.vllm_flash_attn_version = get_flash_attn_version()
if self.vllm_flash_attn_version is not None:
self.flash_attn_varlen_func = \
functools.partial(flash_attn_varlen_func,
fa_version=self.vllm_flash_attn_version)
# For MLA the v head dim is smaller than qk head dim so we pad out
# v with 0s to match the qk head dim for attention backends that do
# not support different headdims
# We don't need to pad V if we are on a hopper system with FA3
self._pad_v = self.vllm_flash_attn_version is None or not (
self.vllm_flash_attn_version == 3
and current_platform.get_device_capability()[0] == 9)
def _flash_attn_varlen_diff_headdims(self,
q,
k,
v,
return_softmax_lse=False,
softmax_scale=None,
**kwargs):
maybe_padded_v = v
if self._pad_v:
maybe_padded_v = torch.nn.functional.pad(
v, [0, q.shape[-1] - v.shape[-1]], value=0)
if is_vllm_fa:
attn_out = self.flash_attn_varlen_func(
q=q,
k=k,
v=maybe_padded_v,
return_softmax_lse=return_softmax_lse,
softmax_scale=softmax_scale,
**kwargs,
)
else:
# Use return_attn_probs instead of return_softmax_lse for RoCM
attn_out = self.flash_attn_varlen_func(
q=q,
k=k,
v=maybe_padded_v,
return_attn_probs=return_softmax_lse,
softmax_scale=softmax_scale,
**kwargs,
)
# Unpack the output if there is multiple results
lse = None
if isinstance(attn_out, tuple):
attn_out, lse = attn_out[0], attn_out[1]
# Remain consistent with old `flash_attn_varlen_func` where there
# is only one output tensor if `return_softmax_lse` is False.
if return_softmax_lse:
return attn_out, lse
return attn_out
def _v_up_proj(self, x):
# Convert from (B, N, L) to (N, B, L)
x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
# Multiply (N, B, L) x (N, L, V) -> (N, B, V)
x = torch.bmm(x, self.W_UV)
# Convert from (N, B, V) to (B, N * V)
return x.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim)
def process_weights_after_loading(self, act_dtype: torch.dtype):
def get_layer_weight(layer):
WEIGHT_NAMES = ("weight", "qweight", "weight_packed")
for attr in WEIGHT_NAMES:
if hasattr(layer, attr):
return getattr(layer, attr)
raise AttributeError(
f"Layer '{layer}' has no recognized weight attribute:"
f" {WEIGHT_NAMES}.")
def get_and_maybe_dequant_weights(layer: LinearBase):
if not isinstance(layer.quant_method, UnquantizedLinearMethod):
# NOTE: This should only be used offline, since it's O(N^3)
eye = torch.eye(layer.input_size_per_partition,
dtype=act_dtype,
device=get_layer_weight(layer).device)
dequant_weights = layer.quant_method.apply(layer,
eye,
bias=None)
del eye
# standardize to (output, input)
return dequant_weights.T
return layer.weight if not envs.MACA_VLLM_USE_TN_2_NN else layer.weight.T
# we currently do not have quantized bmm's which are needed for
# `W_UV` and `W_UK_T`, we we just store fp16/bf16 copies and perform
# the bmm's in 16-bit, the extra memory overhead of this is fairly low
kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T
assert kv_b_proj_weight.shape == (
self.kv_lora_rank,
self.num_heads * (self.qk_nope_head_dim + self.v_head_dim)), (
f"{kv_b_proj_weight.shape=}, "
f"{self.kv_lora_rank=}, "
f"{self.num_heads=}, "
f"{self.qk_nope_head_dim=}, "
f"{self.v_head_dim=}")
kv_b_proj_weight = kv_b_proj_weight.view(
self.kv_lora_rank,
self.num_heads,
self.qk_nope_head_dim + self.v_head_dim,
)
W_UK, W_UV = kv_b_proj_weight.split(
[self.qk_nope_head_dim, self.v_head_dim], dim=-1)
# Convert from (L, N, V) to (N, L, V)
self.W_UV = W_UV.transpose(0, 1)
# Convert from (L, N, P) to (N, P, L)
self.W_UK_T = W_UK.permute(1, 2, 0)
def _compute_prefill_context(
self,
q: torch.Tensor,
kv_c_and_k_pe_cache: torch.Tensor,
attn_metadata: MLACommonMetadata,
):
assert attn_metadata.prefill is not None
prefill_metadata = attn_metadata.prefill
assert prefill_metadata.chunked_context is not None
output = None
iters = len(prefill_metadata.chunked_context.seq_tot)
workspace = prefill_metadata.chunked_context.workspace
for i in range(iters):
toks = prefill_metadata.chunked_context.seq_tot[i]
ops.gather_cache(
src_cache=kv_c_and_k_pe_cache,
dst=workspace,
block_table=prefill_metadata.block_table,
cu_seq_lens=prefill_metadata.chunked_context.cu_seq_lens[i],
batch_size=attn_metadata.num_prefills,
seq_starts=prefill_metadata.chunked_context.starts[i],
)
kv_c_normed = workspace[:toks]\
[..., :self.kv_lora_rank]
k_pe = workspace[:toks]\
[..., self.kv_lora_rank:].unsqueeze(1)
kv_nope = self.kv_b_proj(kv_c_normed)[0].view( \
-1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
k_nope, v = kv_nope\
.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))),
dim=-1)
attn_output, attn_softmax_lse = \
self._flash_attn_varlen_diff_headdims(
q=q,
k=k,
v=v,
cu_seqlens_q=prefill_metadata.query_start_loc,
cu_seqlens_k=prefill_metadata.chunked_context.cu_seq_lens[i],
max_seqlen_q=prefill_metadata.max_query_len,
max_seqlen_k=prefill_metadata.chunked_context.max_seq_lens[i],
softmax_scale=self.scale,
causal=False, # Context is unmasked
return_softmax_lse=True,
)
if output is None:
output = attn_output
output_lse = attn_softmax_lse
else:
output_tmp = torch.empty_like(output)
output_lse_tmp = torch.empty_like(output_lse)
merge_attn_states(
output=output_tmp,
output_lse=output_lse_tmp,
prefix_output=output,
prefix_lse=output_lse,
suffix_output=attn_output,
suffix_lse=attn_softmax_lse,
)
output = output_tmp
output_lse = output_lse_tmp
return output, output_lse
def _forward_prefill(
self,
q: torch.Tensor,
kv_c_normed: torch.Tensor,
k_pe: torch.Tensor,
kv_c_and_k_pe_cache: torch.Tensor,
attn_metadata: MLACommonMetadata,
) -> torch.Tensor:
assert attn_metadata.prefill is not None
has_context = False # attn_metadata.prefill.chunked_context is not None
kv_nope = self.kv_b_proj(kv_c_normed)[0].view(\
-1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
k_nope, v = kv_nope\
.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1)
output = self._flash_attn_varlen_diff_headdims(
q=q,
k=k,
v=v,
cu_seqlens_q=attn_metadata.prefill.query_start_loc,
cu_seqlens_k=attn_metadata.prefill.query_start_loc,
max_seqlen_q=attn_metadata.prefill.max_query_len,
max_seqlen_k=attn_metadata.prefill.max_query_len,
softmax_scale=self.scale,
causal=True,
# return_softmax_lse=has_context,
)
if has_context:
suffix_output, suffix_lse = output
context_output, context_lse = self._compute_prefill_context( \
q, kv_c_and_k_pe_cache, attn_metadata)
output = torch.empty_like(suffix_output)
merge_attn_states(
output=output,
prefix_output=context_output,
prefix_lse=context_lse,
suffix_output=suffix_output,
suffix_lse=suffix_lse,
)
# unpad if necessary
if self._pad_v:
output = output.view(-1, self.num_heads, q.shape[-1])[..., :v.shape[-1]] \
.reshape(-1, self.num_heads * v.shape[-1])
return output
@abstractmethod
def _forward_decode(
self,
ql_nope: torch.Tensor,
q_pe: torch.Tensor,
kv_c_and_k_pe_cache: torch.Tensor,
attn_metadata: M,
) -> torch.Tensor:
raise NotImplementedError
def forward(
self,
layer: AttentionLayer,
q: torch.Tensor,
k_c_normed: torch.Tensor, # key in unified attn
k_pe: torch.Tensor, # value in unified attn
kv_cache: torch.Tensor,
attn_metadata: M,
output: Optional[torch.Tensor] = None,
) -> torch.Tensor:
assert output is not None, "Output tensor must be provided."
if attn_metadata is None:
# The zero fill is required when used with DP + EP
# to ensure all ranks within a DP group compute the
# same expert outputs.
return output.fill_(0)
num_actual_toks = attn_metadata.num_actual_tokens
# Inputs and outputs may be padded for CUDA graphs
output_padded = output
output = output[:num_actual_toks, ...]
q = q[:num_actual_toks, ...]
k_c_normed = k_c_normed[:num_actual_toks, ...]
k_pe = k_pe[:num_actual_toks, ...]
assert attn_metadata.num_decodes is not None and \
attn_metadata.num_prefills is not None and \
attn_metadata.num_decode_tokens is not None
has_decode = attn_metadata.num_decodes > 0
has_prefill = attn_metadata.num_prefills > 0
num_decode_tokens = attn_metadata.num_decode_tokens
decode_q = q[:num_decode_tokens]
prefill_q = q[num_decode_tokens:]
prefill_k_pe = k_pe[num_decode_tokens:]
prefill_k_c_normed = k_c_normed[num_decode_tokens:]
# write the latent and rope to kv cache
if kv_cache.numel() > 0:
ops.concat_and_cache_mla(
k_c_normed,
k_pe.squeeze(1),
kv_cache,
attn_metadata.slot_mapping.flatten(),
kv_cache_dtype=self.kv_cache_dtype,
scale=layer._k_scale,
)
if has_prefill:
output[num_decode_tokens:] = self._forward_prefill(
prefill_q, prefill_k_c_normed, prefill_k_pe, kv_cache,
attn_metadata)
if has_decode:
assert attn_metadata.decode is not None
decode_q_nope, decode_q_pe = decode_q.split(
[self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
# Convert from (B, N, P) to (N, B, P)
decode_q_nope = decode_q_nope.transpose(0, 1)
# Multiply (N, B, P) x (N, P, L) -> (N, B, L)
decode_ql_nope = torch.bmm(decode_q_nope, self.W_UK_T)
# Convert from (N, B, L) to (B, N, L)
decode_ql_nope = decode_ql_nope.transpose(0, 1)
output[:num_decode_tokens] = self._forward_decode(
decode_ql_nope, decode_q_pe, kv_cache, attn_metadata)
return output_padded

View File

@@ -0,0 +1,97 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Any, Optional
import torch
import vllm._custom_ops as ops
from vllm.attention.backends.abstract import (AttentionType,
is_quantized_kv_cache)
from vllm.logger import init_logger
from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
MLACommonImpl,
MLACommonMetadata)
logger = init_logger(__name__)
class CutlassMLABackend(MLACommonBackend):
@staticmethod
def get_name() -> str:
return "CUTLASS_MLA_VLLM_V1"
@staticmethod
def get_impl_cls() -> type["CutlassMLAImpl"]:
return CutlassMLAImpl
class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]):
def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: int,
alibi_slopes: Optional[list[float]],
sliding_window: Optional[int],
kv_cache_dtype: str,
blocksparse_params: Optional[dict[str, Any]],
logits_soft_cap: Optional[float],
attn_type: str,
kv_sharing_target_layer_name: Optional[str],
# MLA Specific Arguments
**mla_args) -> None:
super().__init__(num_heads, head_size, scale, num_kv_heads,
alibi_slopes, sliding_window, kv_cache_dtype,
blocksparse_params, logits_soft_cap, attn_type,
kv_sharing_target_layer_name, **mla_args)
unsupported_features = [
alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap
]
if any(unsupported_features):
raise NotImplementedError(
"CutlassMLAImpl does not support one of the following: "
"alibi_slopes, sliding_window, blocksparse_params, "
"logits_soft_cap")
if attn_type != AttentionType.DECODER:
raise NotImplementedError("Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"CutlassMLAImpl")
if is_quantized_kv_cache(self.kv_cache_dtype):
raise NotImplementedError(
"CutlassMLA V1 with FP8 KV cache not yet supported")
def _forward_decode(
self,
q_nope: torch.Tensor,
q_pe: torch.Tensor,
kv_c_and_k_pe_cache: torch.Tensor,
attn_metadata: MLACommonMetadata,
) -> torch.Tensor:
assert kv_c_and_k_pe_cache.numel() > 0
assert attn_metadata.decode is not None
if self.kv_cache_dtype.startswith("fp8"):
raise NotImplementedError("FP8 Cutlass MLA not yet supported")
B = q_nope.shape[0]
o = torch.empty((B, self.num_heads, self.kv_lora_rank),
dtype=q_nope.dtype,
device=q_nope.device)
# Run MLA
# Clone q_nope and q_pe to make sure strides computation is correct.
q_nope = q_nope.clone()
q_pe = q_pe.clone()
ops.cutlass_mla_decode(o, q_nope, q_pe, kv_c_and_k_pe_cache,
attn_metadata.decode.seq_lens,
attn_metadata.decode.block_table, self.scale)
return self._v_up_proj(o)

View File

@@ -0,0 +1,180 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from typing import Any, ClassVar, Optional
import torch
from vllm.attention.backends.abstract import (AttentionType,
is_quantized_kv_cache)
from vllm.attention.ops.flashmla import (flash_mla_with_kvcache,
get_mla_metadata,
is_flashmla_supported)
from vllm.logger import init_logger
from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
MLACommonDecodeMetadata,
MLACommonImpl,
MLACommonMetadata,
MLACommonMetadataBuilder)
from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.v1.worker.block_table import BlockTable
logger = init_logger(__name__)
class FlashMLABackend(MLACommonBackend):
@staticmethod
def get_name() -> str:
return "FLASHMLA_VLLM_V1"
@staticmethod
def get_metadata_cls() -> type["FlashMLAMetadata"]:
return FlashMLAMetadata
@staticmethod
def get_builder_cls() -> type["FlashMLAMetadataBuilder"]:
return FlashMLAMetadataBuilder
@staticmethod
def get_impl_cls() -> type["FlashMLAImpl"]:
return FlashMLAImpl
@dataclass
class FlashMLADecodeMetadata(MLACommonDecodeMetadata):
tile_scheduler_metadata: torch.Tensor
num_splits: torch.Tensor
@dataclass
class FlashMLAMetadata(MLACommonMetadata[FlashMLADecodeMetadata]):
pass
class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
full_cudagraph_supported: ClassVar[bool] = True # Decode-only
def __init__(self, runner, kv_cache_spec: AttentionSpec,
block_table: BlockTable):
super().__init__(runner, kv_cache_spec, block_table, FlashMLAMetadata)
self.num_q_heads = self.runner.model_config.get_num_attention_heads(
self.runner.parallel_config)
self.cg_buf_tile_scheduler_metadata = None
self.cg_buf_num_splits = None
def _build_decode(self, block_table_tensor: torch.Tensor,
seq_lens: torch.Tensor) -> FlashMLADecodeMetadata:
tile_scheduler_metadata, num_splits = \
get_mla_metadata(
seq_lens,
self.num_q_heads,
1, # MQA for the decode path
)
if self.runner.full_cuda_graph:
# First time around (CUDAGraph capture), allocate the static buffer
if self.cg_buf_tile_scheduler_metadata is None:
self.cg_buf_tile_scheduler_metadata = tile_scheduler_metadata
self.cg_buf_num_splits = num_splits
else:
assert self.cg_buf_num_splits is not None
# Metadata per-SM, fixed size (#SMs, TileMetadataSize)
assert (self.cg_buf_tile_scheduler_metadata.size() ==
tile_scheduler_metadata.size())
self.cg_buf_tile_scheduler_metadata.\
copy_(tile_scheduler_metadata)
tile_scheduler_metadata = self.cg_buf_tile_scheduler_metadata
# Num splits is per-batch, varying size (batch_size,)
n = num_splits.size(0)
# make sure static buffer is large enough
assert n <= self.cg_buf_num_splits.size(0)
num_splits_view = self.cg_buf_num_splits[:n]
num_splits_view.copy_(num_splits)
self.cg_buf_num_splits[n:].fill_(0) # fill the rest with 0s
num_splits = num_splits_view
return FlashMLADecodeMetadata(
block_table=block_table_tensor,
seq_lens=seq_lens,
tile_scheduler_metadata=tile_scheduler_metadata,
num_splits=num_splits,
)
class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: int,
alibi_slopes: Optional[list[float]],
sliding_window: Optional[int],
kv_cache_dtype: str,
blocksparse_params: Optional[dict[str, Any]],
logits_soft_cap: Optional[float],
attn_type: str,
kv_sharing_target_layer_name: Optional[str],
# MLA Specific Arguments
**mla_args) -> None:
super().__init__(num_heads, head_size, scale, num_kv_heads,
alibi_slopes, sliding_window, kv_cache_dtype,
blocksparse_params, logits_soft_cap, attn_type,
kv_sharing_target_layer_name, **mla_args)
assert is_flashmla_supported(), \
"FlashMLA is not supported on this device"
unsupported_features = [
alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap
]
if any(unsupported_features):
raise NotImplementedError(
"FlashMLAImpl does not support one of the following: "
"alibi_slopes, sliding_window, blocksparse_params, "
"logits_soft_cap")
if attn_type != AttentionType.DECODER:
raise NotImplementedError("Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"FlashMLAImpl")
if is_quantized_kv_cache(self.kv_cache_dtype):
raise NotImplementedError(
"FlashMLA V1 with FP8 KV cache not yet supported")
def _forward_decode(
self,
q_nope: torch.Tensor,
q_pe: torch.Tensor,
kv_c_and_k_pe_cache: torch.Tensor,
attn_metadata: FlashMLAMetadata,
) -> torch.Tensor:
assert kv_c_and_k_pe_cache.numel() > 0
assert attn_metadata.decode is not None
q = torch.cat([q_nope, q_pe], dim=-1)\
.unsqueeze(1) # Add seqlen dim of 1 (decode)
o, _ = flash_mla_with_kvcache(
q=q,
k_cache=kv_c_and_k_pe_cache.unsqueeze(-2), # Add head dim of 1
block_table=attn_metadata.decode.block_table,
cache_seqlens=attn_metadata.decode.seq_lens,
head_dim_v=self.kv_lora_rank,
tile_scheduler_metadata=attn_metadata.decode.
tile_scheduler_metadata,
num_splits=attn_metadata.decode.num_splits,
softmax_scale=self.scale,
causal=True,
)
return self._v_up_proj(o)

View File

@@ -0,0 +1,220 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from typing import Any, Optional
import torch
import vllm.envs as envs
from vllm.attention.ops.rocm_aiter_mla import aiter_mla_decode_fwd
# yapf conflicts with isort for this docstring
# yapf: disable
from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
MLACommonDecodeMetadata,
MLACommonImpl,
MLACommonMetadata,
MLACommonMetadataBuilder)
from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.v1.worker.block_table import BlockTable
# yapf: enable
def is_aiter_mla_enabled() -> bool:
return envs.VLLM_ROCM_USE_AITER \
and envs.VLLM_ROCM_USE_AITER_MLA
class AiterMLABackend(MLACommonBackend):
@staticmethod
def get_name() -> str:
return "ROCM_AITER_MLA_VLLM_V1"
@staticmethod
def get_impl_cls() -> type["AiterMLAImpl"]:
return AiterMLAImpl
@staticmethod
def get_metadata_cls() -> type["AiterMLAMetadata"]:
return AiterMLAMetadata
@staticmethod
def get_builder_cls() -> type["AiterMLAMetadataBuilder"]:
return AiterMLAMetadataBuilder
@dataclass
class AiterMLADecodeMetadata(MLACommonDecodeMetadata):
# The indptr of the paged kv cache, shape: [batch_size + 1]
paged_kv_indptr: Optional[torch.Tensor] = None
# The page indices of the paged kv cache
paged_kv_indices: Optional[torch.Tensor] = None
# The number of entries in the last page of each request in
# the paged kv cache, shape: [batch_size]
paged_kv_last_page_len: Optional[torch.Tensor] = None
# The query indptr, shape : [num_decode + 1]
qo_indptr: Optional[torch.Tensor] = None
class AiterMLAMetadata(MLACommonMetadata[AiterMLADecodeMetadata]):
pass
class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
def __init__(self, runner, kv_cache_spec: AttentionSpec,
block_table: BlockTable):
super().__init__(runner, kv_cache_spec, block_table, AiterMLAMetadata)
assert self.kv_cache_spec.block_size == 1, "AITER MLA" \
"only supports block size 1."
def _get_paged_kv_tensors(
self, block_table: torch.Tensor,
seq_lens: torch.Tensor) -> tuple[torch.Tensor, ...]:
page_size = self.kv_cache_spec.block_size
block_table_bounds = (seq_lens + page_size - 1) // page_size
device = self.runner.device
mask = (torch.arange(block_table.size(1),
dtype=block_table.dtype,
device=device).unsqueeze(0)
< block_table_bounds.unsqueeze(1))
paged_kv_indices = block_table[mask]
paged_kv_indptr = torch.cat([
torch.zeros(1, dtype=block_table_bounds.dtype, device=device),
block_table_bounds.cumsum(dim=0, dtype=torch.int32)
])
paged_kv_last_page_len = seq_lens % page_size
paged_kv_last_page_len = torch.where(paged_kv_last_page_len == 0,
page_size, paged_kv_last_page_len)
qo_indptr = torch.arange(0,
self._num_decodes + 1,
step=1,
dtype=torch.int32,
device=device)
return (
paged_kv_indices,
paged_kv_indptr,
paged_kv_last_page_len,
qo_indptr,
)
def _build_decode(self, block_table_tensor: torch.Tensor,
seq_lens: torch.Tensor) -> AiterMLADecodeMetadata:
(
paged_kv_indices,
paged_kv_indptr,
paged_last_page_len,
qo_indptr,
) = self._get_paged_kv_tensors(block_table_tensor, seq_lens)
attn_metadata = AiterMLADecodeMetadata(
block_table=block_table_tensor,
seq_lens=seq_lens,
paged_kv_indptr=paged_kv_indptr,
paged_kv_indices=paged_kv_indices,
paged_kv_last_page_len=paged_last_page_len,
qo_indptr=qo_indptr)
return attn_metadata
class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: int,
alibi_slopes: Optional[list[float]],
sliding_window: Optional[int],
kv_cache_dtype: str,
blocksparse_params: Optional[dict[str, Any]],
logits_soft_cap: Optional[float],
attn_type: str,
kv_sharing_target_layer_name: Optional[str],
# MLA Specific Arguments
**mla_args) -> None:
super().__init__(num_heads, head_size, scale, num_kv_heads,
alibi_slopes, sliding_window, kv_cache_dtype,
blocksparse_params, logits_soft_cap, attn_type,
kv_sharing_target_layer_name, **mla_args)
assert (num_heads == 16 or num_heads == 128), (
f"Aiter MLA only supports 16 or 128 number of heads.\n"
f"Provided {num_heads} number of heads.\n"
"Try adjusting tensor_parallel_size value.")
unsupported_features = [
alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap
]
if any(unsupported_features):
raise NotImplementedError(
"Aiter MLA does not support one of the following: "
"alibi_slopes, sliding_window, blocksparse_params, "
"logits_soft_cap")
from aiter import flash_attn_varlen_func
self.flash_attn_varlen_func = flash_attn_varlen_func
def _flash_attn_varlen_diff_headdims(self,
q,
k,
v,
return_softmax_lse=False,
softmax_scale=None,
**kwargs):
output = self.flash_attn_varlen_func(
q=q,
k=k,
v=v,
softmax_scale=softmax_scale,
return_lse=return_softmax_lse,
**kwargs,
)
return output
def _forward_decode(
self,
q_nope: torch.Tensor,
q_pe: torch.Tensor,
kv_c_and_k_pe_cache: torch.Tensor,
attn_metadata: AiterMLAMetadata,
) -> torch.Tensor:
assert kv_c_and_k_pe_cache.numel() > 0
assert attn_metadata.decode is not None
B = q_nope.shape[0]
q = torch.cat([q_nope, q_pe], dim=-1)
o = torch.zeros(B,
self.num_heads,
self.kv_lora_rank,
dtype=q.dtype,
device=q.device)
kv_buffer = kv_c_and_k_pe_cache.unsqueeze(2)
if self.num_heads == 16:
# AITER MLA decode kernel only supports
# max_seqlen_q=1 when using 16 heads.
max_seqlen_qo = 1
else:
# AITER MLA decode Kernel handles arbitrary
# max_seqlen_q values when using 128 heads.
assert attn_metadata.prefill is not None
max_seqlen_qo = attn_metadata.prefill.max_query_len
aiter_mla_decode_fwd(q, kv_buffer, o, self.scale,
attn_metadata.decode.qo_indptr, max_seqlen_qo,
attn_metadata.decode.paged_kv_indptr,
attn_metadata.decode.paged_kv_indices,
attn_metadata.decode.paged_kv_last_page_len)
return self._v_up_proj(o)

View File

@@ -0,0 +1,162 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from typing import Any, Optional
import torch
from vllm.attention.backends.abstract import (AttentionType,
is_quantized_kv_cache)
from vllm.attention.ops.triton_decode_attention import decode_attention_fwd
from vllm.logger import init_logger
from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
MLACommonDecodeMetadata,
MLACommonImpl,
MLACommonMetadata,
MLACommonMetadataBuilder)
from vllm.attention.backends.triton_mla import (load_config,
find_best_mla_para)
logger = init_logger(__name__)
import os
# TODO: Configure environment variables temporarily. New versions do not need to be configured
os.environ['TRITON_ENABLE_MACA_OPT_MOVE_DOT_OPERANDS_OUT_LOOP'] = '1'
os.environ['TRITON_ENABLE_MACA_CHAIN_DOT_OPT'] = '1'
JSON_DATA = load_config()
class TritonMLABackend(MLACommonBackend):
@staticmethod
def get_name() -> str:
return "TRITON_MLA_VLLM_V1"
@staticmethod
def get_metadata_cls() -> type["TritonMLAMetadata"]:
return TritonMLAMetadata
@staticmethod
def get_builder_cls() -> type["TritonMLAMetadataBuilder"]:
return TritonMLAMetadataBuilder
@staticmethod
def get_impl_cls() -> type["TritonMLAImpl"]:
return TritonMLAImpl
@dataclass
class TritonMLADecodeMetadata(MLACommonDecodeMetadata):
num_kv_splits: int
num_stages: int
@dataclass
class TritonMLAMetadata(MLACommonMetadata[TritonMLADecodeMetadata]):
pass
class TritonMLAMetadataBuilder(MLACommonMetadataBuilder[TritonMLAMetadata]):
def _build_decode(self, block_table_tensor: torch.Tensor,
seq_lens: torch.Tensor) -> TritonMLADecodeMetadata:
if seq_lens is not None:
batch = seq_lens.shape[0]
max_seq_len = int(seq_lens.max())
num_kv_splits, num_stages = find_best_mla_para(JSON_DATA, batch, max_seq_len, 8)
else:
num_kv_splits = 4
num_stages = 1
return TritonMLADecodeMetadata(
block_table=block_table_tensor,
seq_lens=seq_lens,
num_kv_splits=num_kv_splits,
num_stages=num_stages,
)
class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: int,
alibi_slopes: Optional[list[float]],
sliding_window: Optional[int],
kv_cache_dtype: str,
blocksparse_params: Optional[dict[str, Any]],
logits_soft_cap: Optional[float],
attn_type: str,
kv_sharing_target_layer_name: Optional[str],
# MLA Specific Arguments
**mla_args) -> None:
super().__init__(num_heads, head_size, scale, num_kv_heads,
alibi_slopes, sliding_window, kv_cache_dtype,
blocksparse_params, logits_soft_cap, attn_type,
kv_sharing_target_layer_name, **mla_args)
unsupported_features = [
alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap
]
if any(unsupported_features):
raise NotImplementedError(
"TritonMLAImpl does not support one of the following: "
"alibi_slopes, sliding_window, blocksparse_params, "
"logits_soft_cap")
if attn_type != AttentionType.DECODER:
raise NotImplementedError("Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"TritonMLAImpl")
if is_quantized_kv_cache(self.kv_cache_dtype):
raise NotImplementedError(
"TritonMLA V1 with FP8 KV cache not yet supported")
def _forward_decode(
self,
q_nope: torch.Tensor,
q_pe: torch.Tensor,
kv_c_and_k_pe_cache: torch.Tensor,
attn_metadata: MLACommonMetadata,
) -> torch.Tensor:
assert kv_c_and_k_pe_cache.numel() > 0
assert attn_metadata.decode is not None
if self.kv_cache_dtype.startswith("fp8"):
raise NotImplementedError("FP8 Triton MLA not yet supported")
B = q_nope.shape[0]
q = torch.cat([q_nope, q_pe], dim=-1)
o = torch.zeros(B,
self.num_heads,
self.kv_lora_rank,
dtype=q.dtype,
device=q.device)
# TODO(lucas) Allocate ahead of time
attn_logits = torch.empty(
(
B,
self.num_heads,
attn_metadata.decode.num_kv_splits,
# NOTE(lucas) idk why the +1 is here but sglang has it so we
# just mirror that
self.kv_lora_rank + 1,
),
dtype=torch.float32,
device=q.device,
)
# Add a head dim of 1
kv_c_and_k_pe_cache = kv_c_and_k_pe_cache.unsqueeze(2)
kv_c_cache = kv_c_and_k_pe_cache[..., :self.kv_lora_rank]
PAGE_SIZE = kv_c_and_k_pe_cache.size(1)
# Run MQA
decode_attention_fwd(q, kv_c_and_k_pe_cache, kv_c_cache, o,
attn_metadata.decode.block_table,
attn_metadata.decode.seq_lens, attn_logits,
attn_metadata.decode.num_kv_splits,
attn_metadata.decode.num_stages,
self.scale, PAGE_SIZE)
return self._v_up_proj(o)

View File

@@ -0,0 +1,240 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from typing import Any, Optional
import torch
# Required to register custom ops.
import torch_xla.experimental.custom_kernel # noqa: F401
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionLayer, AttentionType)
from vllm.attention.backends.utils import CommonAttentionState
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.utils import cdiv, next_power_of_2
logger = init_logger(__name__)
class PallasAttentionBackend(AttentionBackend):
@staticmethod
def get_name() -> str:
return "PALLAS_VLLM_V1"
@staticmethod
def get_impl_cls() -> type["PallasAttentionBackendImpl"]:
return PallasAttentionBackendImpl
@staticmethod
def get_metadata_cls() -> type["PallasMetadata"]:
return PallasMetadata
@staticmethod
def get_state_cls() -> type["CommonAttentionState"]:
return CommonAttentionState
@staticmethod
def get_kv_cache_shape(
num_blocks: int,
block_size: int,
num_kv_heads: int,
head_size: int,
) -> tuple[int, ...]:
return (num_blocks, block_size, num_kv_heads * 2, head_size)
@staticmethod
def swap_blocks(
src_kv_cache: torch.Tensor,
dst_kv_cache: torch.Tensor,
src_to_dst: torch.Tensor,
) -> None:
raise RuntimeError("swap_blocks is not used for the TPU backend.")
# In recent TPU generations, up to v6e, the SMEM size is 1MB. The
# block_tables within the PallasMetadata constitute almost the entire SMEM
# requirement. Its size is max_num_seqs * num_page_per_seq * 4 (Int). Here
# we simply make sure that the size is smaller than half of SMEM capacity.
@staticmethod
def get_min_page_size(vllm_config: VllmConfig) -> int:
max_num_page_per_req = (1024 * 1024 // 2 //
vllm_config.scheduler_config.max_num_seqs // 4)
min_page_size = cdiv(vllm_config.model_config.max_model_len,
max_num_page_per_req)
min_page_size = 1 << (min_page_size - 1).bit_length()
return min_page_size
# TPU has limited SREGs (scalar registers), if page_size is too small, we
# can spill SREGs easily which leads to bad performance. The strategy we
# apply here is trying to split max-model-len to 16 pages which make the
# spill less likely. Meanwhile we make sure the page size is in [16, 256].
@staticmethod
def get_page_size(vllm_config: VllmConfig) -> int:
page_size = next_power_of_2(
vllm_config.model_config.max_model_len) // 16
if page_size <= 16:
return 16
if page_size >= 256:
return 256
return page_size
@dataclass
class PallasMetadata:
# NOTE(sang): Definition of context_len, query_len, and seq_len.
# |---------- N-1 iteration --------|
# |---------------- N iteration ---------------------|
# |- tokenA -|......................|-- newTokens ---|
# |---------- context_len ----------|
# |-------------------- seq_len ---------------------|
# |-- query_len ---|
# Used in the PallasAttentionBackendImpl
slot_mapping: torch.Tensor
block_tables: torch.Tensor
context_lens: torch.Tensor
query_start_loc: torch.Tensor
num_seqs: torch.Tensor
class PallasAttentionBackendImpl(AttentionImpl):
def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: int,
alibi_slopes: Optional[list[float]],
sliding_window: Optional[int],
kv_cache_dtype: str,
blocksparse_params: Optional[dict[str, Any]] = None,
logits_soft_cap: Optional[float] = None,
attn_type: str = AttentionType.DECODER,
kv_sharing_target_layer_name: Optional[int] = None,
use_irope: bool = False,
) -> None:
if use_irope:
logger.warning_once(
"Using irope in Pallas is not supported yet, it will fall back "
"to global attention for long context.")
if blocksparse_params is not None:
raise ValueError("Paged attention Pallas kernel does "
"not support block-sparse attention.")
self.num_heads = num_heads
self.head_size = head_size
self.scale = float(scale)
self.num_kv_heads = num_kv_heads
self.sliding_window = sliding_window
self.logits_soft_cap = logits_soft_cap
self.kv_sharing_target_layer_name = kv_sharing_target_layer_name
assert self.num_heads % self.num_kv_heads == 0
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
if head_size % 128 != 0:
raise NotImplementedError("Head size must be a multiple of 128.")
if alibi_slopes is not None:
raise NotImplementedError("Alibi slopes is not supported.")
if kv_cache_dtype != "auto":
raise NotImplementedError("FP8 KV cache dtype is not supported.")
if blocksparse_params is not None:
raise NotImplementedError("Blocksparse is not supported.")
if attn_type != AttentionType.DECODER:
raise NotImplementedError("Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"PallasAttentionBackendImpl")
tpu_version = torch_xla.tpu.version()
if tpu_version < 4:
raise NotImplementedError("TPU version must be 4 or higher.")
def forward(
self,
layer: AttentionLayer,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: PallasMetadata,
output: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass with Pallas attention.
Args:
query: shape = [num_tokens, num_heads * head_size]
key: shape = [num_tokens, num_kv_heads * head_size]
value: shape = [num_tokens, num_kv_heads * head_size]
kv_cache = [num_blocks, block_size, num_kv_heads * 2, head_size]
attn_metadata: Metadata for attention.
Returns:
shape = [num_tokens, num_heads * head_size]
"""
# For determine_available_memory case.
if kv_cache.numel() == 0:
if output is None:
output = torch.ones_like(query)
return output
assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0
num_tokens, hidden_size = query.shape
query = query.view(num_tokens, self.num_heads, self.head_size)
if self.kv_sharing_target_layer_name is None and kv_cache.numel() > 0:
# Write input keys and values to the KV cache.
# Skip this if sharing KV cache with an earlier attention layer.
slot_mapping = attn_metadata.slot_mapping
write_to_kv_cache(key, value, kv_cache, slot_mapping)
output = torch.ops.xla.ragged_paged_attention(
query,
kv_cache,
attn_metadata.context_lens,
attn_metadata.block_tables,
attn_metadata.query_start_loc,
attn_metadata.num_seqs,
# By default, the system utilizes optimized block size and
# vmem_limit_bytes parameters from the kernel repository. However,
# these can be manually adjusted for debugging if necessary.
num_kv_pages_per_block=None,
num_queries_per_block=None,
vmem_limit_bytes=None,
use_kernel=True,
sm_scale=self.scale,
sliding_window=self.sliding_window,
soft_cap=self.logits_soft_cap,
)
return output.reshape(num_tokens, hidden_size)
def write_to_kv_cache(
key: torch.Tensor,
value: torch.Tensor,
kv_cache: torch.Tensor,
slot_mapping: torch.Tensor,
) -> None:
""" Write the key and values to the KV cache.
Args:
key: shape = [num_tokens, num_kv_heads * head_size]
value: shape = [num_tokens, num_kv_heads * head_size]
kv_cache = [num_blocks, block_size, num_kv_heads * 2, head_size]
"""
_, _, num_combined_kv_heads, head_size = kv_cache.shape
num_kv_heads = num_combined_kv_heads // 2
key = key.view(-1, num_kv_heads, head_size)
value = value.view(-1, num_kv_heads, head_size)
kv = torch.cat([key, value], axis=-1).reshape(-1, num_combined_kv_heads,
head_size)
torch.ops.xla.dynamo_set_buffer_donor_(kv_cache, True)
kv_cache = kv_cache.flatten(0, 1)
kv_cache.index_copy_(0, slot_mapping, kv)

View File

@@ -0,0 +1,295 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Attention layer with PagedAttention and Triton prefix prefill."""
from typing import TYPE_CHECKING, Any, Optional
import torch
from vllm import _custom_ops as ops
from vllm import envs
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata, AttentionType)
from vllm.attention.ops.chunked_prefill_paged_decode import (
chunked_prefill_paged_decode)
from vllm.attention.ops.paged_attn import PagedAttention
from vllm.attention.ops.triton_unified_attention import unified_attention
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.v1.attention.backends.flash_attn import (
FlashAttentionMetadata, FlashAttentionMetadataBuilder)
from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.v1.worker.block_table import BlockTable
if TYPE_CHECKING:
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
logger = init_logger(__name__)
class TritonAttentionMetadataBuilder(FlashAttentionMetadataBuilder):
def __init__(self, runner: "GPUModelRunner", kv_cache_spec: AttentionSpec,
block_table: BlockTable):
super().__init__(runner, kv_cache_spec, block_table)
self.aot_schedule = False
class TritonAttentionBackend(AttentionBackend):
accept_output_buffer: bool = True
@staticmethod
def get_supported_head_sizes() -> list[int]:
return [32, 64, 96, 128, 160, 192, 224, 256]
@staticmethod
def get_name() -> str:
return "TRITON_ATTN_VLLM_V1"
@staticmethod
def get_impl_cls() -> type["TritonAttentionImpl"]:
return TritonAttentionImpl
@staticmethod
def get_metadata_cls() -> type["AttentionMetadata"]:
return FlashAttentionMetadata
@staticmethod
def get_kv_cache_shape(
num_blocks: int,
block_size: int,
num_kv_heads: int,
head_size: int,
) -> tuple[int, ...]:
if block_size % 16 != 0:
raise ValueError("Block size must be a multiple of 16.")
return (2, num_blocks, block_size, num_kv_heads, head_size)
@staticmethod
def use_cascade_attention(*args, **kwargs) -> bool:
return False
@staticmethod
def get_builder_cls() -> type["TritonAttentionMetadataBuilder"]:
return TritonAttentionMetadataBuilder
class TritonAttentionImpl(AttentionImpl):
def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: int,
alibi_slopes: Optional[list[float]],
sliding_window: Optional[int],
kv_cache_dtype: str,
blocksparse_params: Optional[dict[str, Any]] = None,
logits_soft_cap: Optional[float] = None,
attn_type: AttentionType = AttentionType.DECODER,
kv_sharing_target_layer_name: Optional[int] = None,
use_irope: bool = False,
sinks: Optional[torch.Tensor] = None,
) -> None:
if blocksparse_params is not None:
raise ValueError(
"TritonAttention does not support block-sparse attention.")
self.num_heads = num_heads
self.head_size = head_size
self.scale = float(scale)
self.num_kv_heads = num_kv_heads
if alibi_slopes is not None:
alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
self.alibi_slopes = alibi_slopes
if sliding_window is None:
self.sliding_window = (-1, -1)
else:
self.sliding_window = (sliding_window - 1, 0)
self.kv_cache_dtype = kv_cache_dtype
if logits_soft_cap is None:
# In flash-attn, setting logits_soft_cap as 0 means no soft cap.
logits_soft_cap = 0
self.logits_soft_cap = logits_soft_cap
self.kv_sharing_target_layer_name = kv_sharing_target_layer_name
self.use_irope = use_irope
assert self.num_heads % self.num_kv_heads == 0
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
support_head_sizes = TritonAttentionBackend.get_supported_head_sizes()
if head_size not in support_head_sizes:
raise ValueError(
f"Head size {head_size} is not supported by TritonAttention. "
f"Supported head sizes are: {support_head_sizes}.")
if attn_type != AttentionType.DECODER:
raise NotImplementedError("Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"TritonAttentionImpl")
self.fp8_dtype = current_platform.fp8_dtype()
self.force_prefill_decode_attn = \
envs.VLLM_V1_USE_PREFILL_DECODE_ATTENTION
self.sinks = sinks
if sinks is not None:
assert sinks.shape[0] == num_heads, (
"Sinks must have the same number of heads as the number of "
f"heads in the layer. Sinks shape: {sinks.shape}, "
f"num_heads: {num_heads}.")
def forward(
self,
layer: torch.nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: FlashAttentionMetadata,
output: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass with FlashAttention.
Args:
query: shape = [num_tokens, num_heads, head_size]
key: shape = [num_tokens, num_kv_heads, head_size]
value: shape = [num_tokens, num_kv_heads, head_size]
kv_cache = [2, num_blocks, block_size, num_kv_heads, head_size]
attn_metadata: Metadata for attention.
Returns:
shape = [num_tokens, num_heads * head_size]
"""
assert output is not None, "Output tensor must be provided."
if attn_metadata is None:
# Profiling run.
return output
assert attn_metadata.use_cascade is False
# IMPORTANT!
# NOTE(woosuk): With piece-wise CUDA graphs, this method is executed in
# eager-mode PyTorch. Thus, we need to be careful about any CPU overhead
# in this method. For example, `view` and `slice` (or `[:n]`) operations
# are surprisingly slow even in the case they do not invoke any GPU ops.
# Minimize the PyTorch ops in this method as much as possible.
# Whenever making a change in this method, please benchmark the
# performance to make sure it does not introduce any overhead.
use_prefill_decode_attn = self.force_prefill_decode_attn
num_actual_tokens = attn_metadata.num_actual_tokens
if use_prefill_decode_attn:
key_cache, value_cache = PagedAttention.split_kv_cache(
kv_cache, self.num_kv_heads, self.head_size)
else:
key_cache, value_cache = kv_cache.unbind(0)
if self.kv_sharing_target_layer_name is None:
# Reshape the input keys and values and store them in the cache.
# Skip this if sharing KV cache with an earlier attention layer.
if use_prefill_decode_attn:
PagedAttention.write_to_paged_cache(
key,
value,
key_cache,
value_cache,
attn_metadata.slot_mapping,
self.kv_cache_dtype,
layer._k_scale,
layer._v_scale,
)
else:
torch.ops._C_cache_ops.reshape_and_cache_flash(
key,
value,
key_cache,
value_cache,
attn_metadata.slot_mapping,
self.kv_cache_dtype,
layer._k_scale,
layer._v_scale,
)
if self.kv_cache_dtype.startswith("fp8"):
key_cache = key_cache.view(self.fp8_dtype)
value_cache = value_cache.view(self.fp8_dtype)
num_tokens, num_heads, head_size = query.shape
assert layer._q_scale == 1.0, \
"A non 1.0 q_scale is not currently supported."
if not current_platform.is_rocm():
# Skip Q quantization on ROCm, since dequantizing back to
# f32 in the attention kernel is not supported.
query, _ = ops.scaled_fp8_quant(
query.reshape(
(num_tokens, num_heads * head_size)).contiguous(),
layer._q_scale)
query = query.reshape((num_tokens, num_heads, head_size))
use_local_attn = \
(self.use_irope and attn_metadata.local_attn_metadata is not None)
if use_local_attn:
assert attn_metadata.local_attn_metadata is not None
local_metadata = attn_metadata.local_attn_metadata
cu_seqlens_q = local_metadata.local_query_start_loc
seqused_k = local_metadata.local_seqused_k
max_seqlen_q = local_metadata.local_max_query_len
max_seqlen_k = local_metadata.local_max_seq_len
block_table = local_metadata.local_block_table
else:
cu_seqlens_q = attn_metadata.query_start_loc
seqused_k = attn_metadata.seq_lens
max_seqlen_q = attn_metadata.max_query_len
max_seqlen_k = attn_metadata.max_seq_len
block_table = attn_metadata.block_table
if use_prefill_decode_attn:
# Compute attention and update output up to `num_actual_tokens`.
chunked_prefill_paged_decode(query=query[:num_actual_tokens],
key=key[:num_actual_tokens],
value=value[:num_actual_tokens],
output=output[:num_actual_tokens],
kv_cache_dtype=self.kv_cache_dtype,
key_cache=key_cache,
value_cache=value_cache,
block_table=block_table,
query_start_loc=cu_seqlens_q,
seq_lens=seqused_k,
max_seq_len=max_seqlen_k,
max_query_len=max_seqlen_q,
k_scale=layer._k_scale,
v_scale=layer._v_scale,
alibi_slopes=self.alibi_slopes,
sliding_window=self.sliding_window[0],
sm_scale=self.scale,
sinks=self.sinks)
else:
descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1])
unified_attention(
q=query[:num_actual_tokens],
k=key_cache,
v=value_cache,
out=output[:num_actual_tokens],
cu_seqlens_q=cu_seqlens_q,
max_seqlen_q=max_seqlen_q,
seqused_k=seqused_k,
max_seqlen_k=max_seqlen_k,
softmax_scale=self.scale,
causal=True,
alibi_slopes=self.alibi_slopes,
window_size=self.sliding_window,
block_table=block_table,
softcap=self.logits_soft_cap,
q_descale=None, # Not supported
k_descale=layer._k_scale.expand(descale_shape),
v_descale=layer._v_scale.expand(descale_shape),
sinks=self.sinks,
)
return output

View File

@@ -0,0 +1,118 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import abc
from abc import abstractmethod
from dataclasses import dataclass
from typing import TYPE_CHECKING, ClassVar, Generic, TypeVar
import numpy as np
import torch
if TYPE_CHECKING:
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.worker.gpu_input_batch import InputBatch
@dataclass
class CommonAttentionMetadata:
"""
Per-batch attention metadata, shared across layers and backends.
AttentionMetadataBuilder instances use it to construct per-layer metadata.
"""
query_start_loc: torch.Tensor
"""(batch_size + 1,), the start location of each request in query Tensor"""
seq_lens: torch.Tensor
"""(batch_size,), the length of each request including both computed tokens
and newly scheduled tokens"""
num_reqs: int
"""Number of requests"""
num_actual_tokens: int
"""Total number of tokens in batch"""
max_query_len: int
"""Longest query in batch"""
M = TypeVar("M")
class AttentionMetadataBuilder(abc.ABC, Generic[M]):
# Does this backend/builder support CUDA Graphs for attention.
full_cudagraph_supported: ClassVar[bool] = False
@abstractmethod
def build(self, common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata) -> M:
"""
Central method that builds attention metadata.
Some builders (MLA) require reorder_batch to be called prior to build.
"""
raise NotImplementedError
def can_run_in_cudagraph(
self, common_attn_metadata: CommonAttentionMetadata) -> bool:
"""
Can this batch (with given metadata) use CUDA Graphs for attention.
"""
return False
def build_for_cudagraph_capture(
self, common_attn_metadata: CommonAttentionMetadata) -> M:
"""
Build attention metadata for CUDA graph capture. Uses build by default.
Subclasses that override this method should call self.build or
super().build_for_cudagraph_capture.
"""
return self.build(common_prefix_len=0,
common_attn_metadata=common_attn_metadata)
def use_cascade_attention(
self,
common_prefix_len: int,
query_lens: np.ndarray,
num_query_heads: int,
num_kv_heads: int,
use_alibi: bool,
use_sliding_window: bool,
num_sms: int,
) -> bool:
return False
def reorder_batch(self, input_batch: "InputBatch",
scheduler_output: "SchedulerOutput") -> bool:
"""
This method can reorder the batch if desired by the backend.
:return: Has the batch been reordered (default False).
"""
return False
def validate_kv_sharing_target(current_layer_name, target_layer_name,
static_forward_context):
error_msg = (f"Specified KV sharing target layer for {current_layer_name} "
f"is not valid: target layer {target_layer_name} ")
if current_layer_name == target_layer_name:
raise ValueError(error_msg +
"cannot be the same as the current layer.")
if target_layer_name not in static_forward_context:
from vllm.model_executor.models.utils import extract_layer_index
# If target layer name is not in the static fwd context, it means either
# a) the target layer does not come BEFORE the current layer, or
# b) the target layer is not an Attention layer that exists in the model
current_layer_idx = extract_layer_index(current_layer_name)
target_layer_idx = extract_layer_index(target_layer_name)
if current_layer_idx <= target_layer_idx:
raise ValueError(error_msg + "must come before the current layer.")
else:
raise ValueError(error_msg +
"is not a valid Attention layer in the model.")
# Currently KV sharing is only supported between layers of the same type
target_layer_attn_type = static_forward_context[
target_layer_name].attn_type
expected = static_forward_context[current_layer_name].attn_type
if target_layer_attn_type != expected:
raise ValueError(
error_msg +
f"must be the same type as the current layer ({expected}).")