# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

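# ROCm AITER MLA attention backend: plugs AMD's AITER MLA decode kernel
# (aiter_mla_decode_fwd) into the shared MLACommon* attention classes.
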
from contextlib import contextmanager
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Optional, Type, Union

import torch

import vllm._custom_ops as ops
import vllm.envs as envs
from vllm.attention.backends.mla.common import (MLACommonBackend,
                                                MLACommonImpl,
                                                MLACommonMetadata,
                                                MLACommonMetadataBuilder,
                                                MLACommonState)
from vllm.attention.backends.utils import (compute_slot_mapping,
                                           compute_slot_mapping_start_idx,
                                           is_block_tables_empty)
from vllm.attention.ops.rocm_aiter_mla import (aiter_mla_decode_fwd,
                                               get_aiter_mla_metadata)

if TYPE_CHECKING:
    from vllm.worker.model_runner import ModelInputForGPUBuilder


def is_aiter_mla_enabled() -> bool:
    return envs.VLLM_ROCM_USE_AITER \
        and envs.VLLM_ROCM_USE_AITER_MLA


class AiterMLABackend(MLACommonBackend):

    @staticmethod
    def get_name() -> str:
        return "ROCM_AITER_MLA"

    @staticmethod
    def get_impl_cls() -> Type["AiterMLAImpl"]:
        return AiterMLAImpl

    @staticmethod
    def get_metadata_cls() -> Type["AiterMLAMetadata"]:
        return AiterMLAMetadata

    @staticmethod
    def get_builder_cls() -> Type["AiterMLAMetadataBuilder"]:
        return AiterMLAMetadataBuilder

    @staticmethod
    def get_state_cls() -> Type["AiterMLAState"]:
        return AiterMLAState


@dataclass
class AiterMLAMetadata(MLACommonMetadata):
    # The following 5 tensors are for the current version of AITER MLA
    block_table_bound: Optional[torch.Tensor] = None
    # The indptr of the paged kv cache, shape: [batch_size + 1]
    paged_kv_indptr: Optional[torch.Tensor] = None
    # The page indices of the paged kv cache
    paged_kv_indices: Optional[torch.Tensor] = None
    # The number of entries in the last page of each request in
    # the paged kv cache, shape: [batch_size]
    paged_kv_last_page_lens: Optional[torch.Tensor] = None

    # This is just to make the new AITER MLA API work
    # -- MTP support is not added yet.
    qo_indptr: Optional[torch.Tensor] = None

    @property
    def prefill_metadata(self):
        prefill_metadata = super().prefill_metadata
        self._cached_prefill_metadata = prefill_metadata

        if prefill_metadata is not None:
            prefill_metadata.paged_kv_indptr = self.paged_kv_indptr
            prefill_metadata.paged_kv_indices = self.paged_kv_indices
            prefill_metadata\
                .paged_kv_last_page_lens = self.paged_kv_last_page_lens
            prefill_metadata.block_table_bound = self.block_table_bound
            prefill_metadata.qo_indptr = self.qo_indptr

            # update the cache
            self._cached_prefill_metadata = self.__class__(
                **prefill_metadata.__dict__)

        return self._cached_prefill_metadata

    @property
    def decode_metadata(self):
        decode_metadata = super().decode_metadata

        self._cached_decode_metadata = decode_metadata

        if decode_metadata is not None:
            decode_metadata.paged_kv_indptr = self.paged_kv_indptr
            decode_metadata.paged_kv_indices = self.paged_kv_indices
            decode_metadata\
                .paged_kv_last_page_lens = self.paged_kv_last_page_lens
            decode_metadata.block_table_bound = self.block_table_bound
            decode_metadata.qo_indptr = self.qo_indptr

            # update the cache
            self._cached_decode_metadata = self.__class__(
                **decode_metadata.__dict__)

        return self._cached_decode_metadata

    def _ops_advance_step(self, num_seqs: int, num_queries: int,
                          block_size: int, input_tokens: torch.Tensor,
                          sampled_token_ids: torch.Tensor,
                          input_positions: torch.Tensor) -> None:

        ops.advance_step_flashinfer(
            num_seqs=num_seqs,
            num_queries=num_queries,
            block_size=block_size,
            input_tokens=input_tokens,
            sampled_token_ids=sampled_token_ids,
            input_positions=input_positions,
            seq_lens=self.seq_lens_tensor,
            slot_mapping=self.slot_mapping,
            block_tables=self.block_tables,
            paged_kv_indices=self.paged_kv_indices,
            paged_kv_indptr=self.paged_kv_indptr,
            paged_kv_last_page_lens=self.paged_kv_last_page_lens,
            block_table_bound=self.block_table_bound)


class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
    BLOCK_TABLE_EXTENDER: list[list[int]] = [[]]

    def __init__(self, input_builder: "ModelInputForGPUBuilder"):
        super().__init__(input_builder)
        assert self.block_size == 1, "AITER MLA requires a block size of 1."

    def prepare(self):
        super().prepare()
        self.paged_kv_indices: list[int] = []
        self.paged_kv_indptr: list[int] = [0]
        self.paged_kv_last_page_lens: list[int] = []
        self.total_blocks = 0
        self.qo_indptr: list[int] = [0]

    def _add_seq_group(self, inter_data, chunked_prefill_enabled: bool,
                       prefix_cache_hit: bool):
        """Add a sequence group to the metadata. Specifically update/append
        1. context length.
        2. block table.
        3. slot mapping.
        """
        is_prompt = inter_data.is_prompt
        block_tables = inter_data.block_tables

        for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len,
             curr_sliding_window_block) in zip(
                 inter_data.seq_ids, [len(t) for t in inter_data.input_tokens],
                 inter_data.orig_seq_lens, inter_data.seq_lens,
                 inter_data.query_lens, inter_data.context_lens,
                 inter_data.curr_sliding_window_blocks):
            self.context_lens.append(context_len)
            if is_prompt:
                self.num_prefills += 1
                self.num_prefill_tokens += token_len
                self.prefill_seq_lens.append(seq_len)
            else:
                self.num_decode_tokens += query_len
                self.curr_seq_lens.append(curr_seq_len)

            # Compute block table.
            # TODO(sang): Combine chunked prefill and prefix caching by
            # only allowing multiple of block_size chunk size.
            # NOTE: This only works for oooooooxxx style attention.
            block_table = []
            if prefix_cache_hit:
                # NOTE(woosuk): For flash-attn, the block table should
                # include the entries for the incoming prefill tokens.
                block_table = block_tables[seq_id]
            elif ((chunked_prefill_enabled or not is_prompt)
                  and block_tables is not None):
                if curr_sliding_window_block == 0:
                    block_table = block_tables[seq_id]
                else:
                    block_table = block_tables[seq_id][
                        -curr_sliding_window_block:]
            self.block_tables.append(block_table)

            # Compute slot mapping.
            is_profile_run = is_block_tables_empty(block_tables)
            start_idx = compute_slot_mapping_start_idx(is_prompt, query_len,
                                                       context_len,
                                                       self.sliding_window)
            compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id,
                                 seq_len, context_len, start_idx,
                                 self.block_size, inter_data.block_tables)
            if is_profile_run:
                return

            # Update paged_kv_* tensors only for non-profile run
            block_table = block_tables[seq_id]
            self._update_paged_kv_tensors(block_table, seq_len)

    def _update_paged_kv_tensors(self, block_table: list[int], seq_len: int):
        # Get the number of valid blocks based on sequence length.
        # If seq_len = 16, block_size = 16,
        # block_table_bound is 1 with 1 valid block.
        # If seq_len = 15, block_size = 16,
        # block_table_bound is 0 + 1 with 1 valid block.
        self.total_blocks += len(block_table)
        block_table_bound = seq_len // self.block_size + 1 \
            if seq_len % self.block_size != 0 \
            else seq_len // self.block_size
        self.paged_kv_indices.extend(block_table[:block_table_bound])
        self.paged_kv_indptr.append(self.paged_kv_indptr[-1] +
                                    block_table_bound)
        self.qo_indptr.append(self.qo_indptr[-1] + 1)

        last_page_len = seq_len % self.block_size
        if last_page_len == 0:
            last_page_len = self.block_size
        self.paged_kv_last_page_lens.append(last_page_len)

    def build(self, seq_lens: list[int], query_lens: list[int],
              cuda_graph_pad_size: int, batch_size: int) -> AiterMLAMetadata:
        metadata = super().build(seq_lens, query_lens, cuda_graph_pad_size,
                                 batch_size)
        device = self.runner.device
        use_captured_graph = cuda_graph_pad_size != -1

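        # When a CUDA graph is captured/replayed, pad the per-request indptr
        # and last-page-length lists up to the padded batch size so the
        # resulting tensors keep a fixed shape.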
        if use_captured_graph:
            last_paged_kv_indptr = self.paged_kv_indptr[-1]
            self.paged_kv_indptr.extend([last_paged_kv_indptr] *
                                        cuda_graph_pad_size)
            self.paged_kv_last_page_lens.extend([0] * cuda_graph_pad_size)
            last_qo_indptr = self.qo_indptr[-1]
            self.qo_indptr.extend([last_qo_indptr] * cuda_graph_pad_size)

        # For the current version of AITER MLA
        if len(self.paged_kv_indptr) > 0:
            # extend to the maximum number of blocks as returned by the
            # scheduler
            self.paged_kv_indices.extend(
                [0] * (self.total_blocks - len(self.paged_kv_indices)))
            paged_kv_indices_tensor = torch.tensor(self.paged_kv_indices,
                                                   device=device,
                                                   dtype=torch.int)
            paged_kv_indptr_tensor = torch.tensor(self.paged_kv_indptr,
                                                  device=device,
                                                  dtype=torch.int)
            paged_kv_last_page_lens_tensor = torch.tensor(
                self.paged_kv_last_page_lens, device=device, dtype=torch.int)
            block_table_bound_tensor = torch.zeros(len(self.paged_kv_indptr) -
                                                   1,
                                                   device=device,
                                                   dtype=torch.int)

            qo_indptr = torch.tensor(self.qo_indptr,
                                     device=device,
                                     dtype=torch.int)
        else:
            paged_kv_indices_tensor = None
            paged_kv_indptr_tensor = None
            paged_kv_last_page_lens_tensor = None
            block_table_bound_tensor = None
            qo_indptr = None

        metadata.paged_kv_indptr = paged_kv_indptr_tensor
        metadata.paged_kv_indices = paged_kv_indices_tensor
        metadata.paged_kv_last_page_lens = paged_kv_last_page_lens_tensor
        metadata.block_table_bound = block_table_bound_tensor
        metadata.qo_indptr = qo_indptr

        return metadata


class AiterMLAState(MLACommonState[AiterMLAMetadata]):

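    # The _paged_kv_* / _qo_indptr tensors below are allocated once per
    # graph-capture context, sliced per batch size in
    # graph_capture_get_metadata_for_batch(), and freed when capture ends.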
    @contextmanager
    def graph_capture(self, max_batch_size: int):
        kv_indices, kv_indptr, last_page_lens, qo_indptr = \
            get_aiter_mla_metadata(
                max_batch_size=max_batch_size,
                block_size=self.runner.block_size,
                max_block_per_batch=\
                    self.runner.get_max_block_per_batch(),
                device=self.runner.device)
        self._paged_kv_indices_tensor = kv_indices
        self._paged_kv_indptr_tensor = kv_indptr
        self._paged_kv_last_page_lens_tensor = last_page_lens
        self._qo_indptr_tensor = qo_indptr

        with super().graph_capture(max_batch_size):
            yield

        del self._paged_kv_indices_tensor
        del self._paged_kv_indptr_tensor
        del self._paged_kv_last_page_lens_tensor
        del self._qo_indptr_tensor

    def graph_capture_get_metadata_for_batch(
            self,
            batch_size: int,
            is_encoder_decoder_model: bool = False) -> AiterMLAMetadata:

        metadata = super().graph_capture_get_metadata_for_batch(
            batch_size, is_encoder_decoder_model)

        paged_kv_indptr = self._paged_kv_indptr_tensor[:batch_size + 1]
        paged_kv_indices = self._paged_kv_indices_tensor
        paged_kv_last_page_lens = self._paged_kv_last_page_lens_tensor[:
                                                                       batch_size]
        qo_indptr = self._qo_indptr_tensor[:batch_size + 1]

        metadata.paged_kv_indptr = paged_kv_indptr
        metadata.paged_kv_indices = paged_kv_indices
        metadata.paged_kv_last_page_lens = paged_kv_last_page_lens
        metadata.qo_indptr = qo_indptr

        return metadata

    def get_graph_input_buffers(self,
                                attn_metadata: AiterMLAMetadata,
                                is_encoder_decoder_model: bool = False):
        input_buffers = super().get_graph_input_buffers(
            attn_metadata, is_encoder_decoder_model)
        input_buffers[
            'paged_kv_indptr'] = attn_metadata.decode_metadata.paged_kv_indptr
        input_buffers[
            "paged_kv_indices"] = attn_metadata.\
            decode_metadata.paged_kv_indices
        input_buffers[
            "paged_kv_last_page_lens"] = attn_metadata.\
            decode_metadata.paged_kv_last_page_lens
        input_buffers['qo_indptr'] = attn_metadata.qo_indptr

        return input_buffers

    def prepare_graph_input_buffers(self,
                                    input_buffers,
                                    attn_metadata: AiterMLAMetadata,
                                    is_encoder_decoder_model: bool = False):
        super().prepare_graph_input_buffers(input_buffers, attn_metadata,
                                            is_encoder_decoder_model)

        num_total_blocks = attn_metadata.decode_metadata.paged_kv_indices.shape[
            0]
        input_buffers["paged_kv_indptr"].copy_(
            attn_metadata.decode_metadata.paged_kv_indptr, non_blocking=True)
        input_buffers["paged_kv_indices"][:num_total_blocks].copy_(
            attn_metadata.decode_metadata.paged_kv_indices, non_blocking=True)
        input_buffers["paged_kv_last_page_lens"].copy_(
            attn_metadata.decode_metadata.paged_kv_last_page_lens,
            non_blocking=True)
        input_buffers["qo_indptr"].copy_(
            attn_metadata.decode_metadata.qo_indptr, non_blocking=True)


class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):

    def __init__(
            self,
            num_heads: int,
            head_size: int,
            scale: float,
            num_kv_heads: int,
            alibi_slopes: Optional[list[float]],
            sliding_window: Optional[int],
            kv_cache_dtype: str,
            blocksparse_params: Optional[dict[str, Any]],
            logits_soft_cap: Optional[float],
            attn_type: str,
            kv_sharing_target_layer_name: Optional[str],
            # MLA Specific Arguments
            **mla_args) -> None:
        super().__init__(num_heads, head_size, scale, num_kv_heads,
                         alibi_slopes, sliding_window, kv_cache_dtype,
                         blocksparse_params, logits_soft_cap, attn_type,
                         kv_sharing_target_layer_name, **mla_args)

        unsupported_features = [
            alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap
        ]
        if any(unsupported_features):
            raise NotImplementedError(
                "Aiter MLA does not support one of the following: "
                "alibi_slopes, sliding_window, blocksparse_params, "
                "logits_soft_cap")

        from aiter import flash_attn_varlen_func
        self.flash_attn_varlen_func = flash_attn_varlen_func

    def _flash_attn_varlen_diff_headdims(
            self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
            softmax_scale: float, return_softmax_lse: bool,
            **kwargs) -> Union[tuple[torch.Tensor, ...], torch.Tensor]:
        output = self.flash_attn_varlen_func(
            q,
            k,
            v,
            **kwargs,
        )

        return output

    def _forward_decode(
        self,
        q_nope: torch.Tensor,
        q_pe: torch.Tensor,
        kv_c_and_k_pe_cache: torch.Tensor,
        attn_metadata: AiterMLAMetadata,
    ) -> torch.Tensor:
        assert kv_c_and_k_pe_cache.numel() > 0

        decode_meta = attn_metadata.decode_metadata
        assert decode_meta is not None
        B = q_nope.shape[0]

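        # The decode kernel consumes the full query ("nope" and rope parts
        # concatenated) and writes its output in the compressed
        # (kv_lora_rank) space, which _v_up_proj() then projects back up.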
        q = torch.cat([q_nope, q_pe], dim=-1)
        o = torch.empty(B,
                        self.num_heads,
                        self.kv_lora_rank,
                        dtype=q.dtype,
                        device=q.device)

        kv_buffer = kv_c_and_k_pe_cache.unsqueeze(2)

        aiter_mla_decode_fwd(q, kv_buffer, o, self.scale,
                             attn_metadata.qo_indptr,
                             attn_metadata.max_query_len,
                             attn_metadata.paged_kv_indptr,
                             attn_metadata.paged_kv_indices,
                             attn_metadata.paged_kv_last_page_lens)

        return self._v_up_proj(o)