from dataclasses import dataclass
from typing import TYPE_CHECKING, ClassVar, Generic, List, Optional, TypeVar, Type

import torch

from vllm import _custom_ops as ops  # used by VACCMLAMetadata.advance_step
from vllm.attention.backends.abstract import (AttentionBackend,
                                              AttentionMetadata, AttentionType)
from vllm.attention.backends.utils import get_mla_dims
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils import cdiv, round_down
from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
                                                   MLACommonImpl,
                                                   MLACommonMetadata)
from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
                                              CommonAttentionMetadata,
                                              get_per_layer_parameters,
                                              infer_global_hyperparameters,
                                              split_decodes_and_prefills)
from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.v1.worker.block_table import BlockTable

if TYPE_CHECKING:
    from vllm.v1.core.sched.output import SchedulerOutput
    from vllm.v1.worker.gpu_input_batch import InputBatch
    from vllm.v1.worker.gpu_model_runner import GPUModelRunner

from vllm_vacc.vllm.model_executor.models.vars import BLOCK_GROUP_SIZE as env_blk_grp_size

logger = init_logger(__name__)

M = TypeVar("M", bound=MLACommonMetadata)


def vacc_paged_attention_naive(
    query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    block_table: torch.Tensor,
    seq_lens: List[int],
    out: Optional[torch.Tensor] = None,
    sm_scale: float = -1,
) -> torch.Tensor:

    # Fast path: guarantee batch=1 performance.
    if len(seq_lens) == 1:
        k = key_cache.view(-1, key_cache.shape[2], key_cache.shape[3])[:seq_lens[0]]
        v = value_cache.view(-1, value_cache.shape[2], value_cache.shape[3])[:seq_lens[0]]
        attn_out = torch.vacc.scaled_dot_product_attention(
            query=query,
            key=k,
            value=v,
            attn_mask=None,
            dropout_p=0,
            is_causal=False,
            is_train=False,
            recompute=False,
            flash_attention=False,
            sm_scale=sm_scale
        )
    else:
        # t0 = time.time()
        attn_outs = []
        for i in range(len(seq_lens)):
            k_slices = key_cache[block_table[i], :, :, :]
            k = torch.cat([k_slices[j, :, :, :].unsqueeze(1)
                           for j in range(len(block_table[i]))], dim=0)
            k = k.view(-1, key_cache.shape[2], key_cache.shape[3])[:seq_lens[i]]
            v_slices = value_cache[block_table[i], :, :, :]
            v = torch.cat([v_slices[j, :, :, :].unsqueeze(1)
                           for j in range(len(block_table[i]))], dim=0)
            v = v.view(-1, value_cache.shape[2], value_cache.shape[3])[:seq_lens[i]]

            attn_out = torch.vacc.scaled_dot_product_attention(
                query=query[i:i + 1, :, :],
                key=k,
                value=v,
                attn_mask=None,
                dropout_p=0,
                is_causal=False,
                is_train=False,
                recompute=False,
                flash_attention=False,
                sm_scale=sm_scale
            )
            attn_outs.append(attn_out)

        attn_out = torch.cat(attn_outs, dim=0)
        # print(f'{os.getpid()} call sdpa(seq: {seq_lens}) time: {time.time() - t0}')
    return attn_out
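
# A minimal usage sketch of the naive paged attention above (hypothetical shapes,
# assuming the caches are laid out as [num_blocks, block_size, num_heads, head_dim]
# and a single decode request with 40 cached tokens):
#
#   q = torch.randn(1, num_heads, head_dim)             # one new token
#   out = vacc_paged_attention_naive(
#       q, key_cache, value_cache,
#       block_table=block_table,                        # [batch, max_blocks_per_seq]
#       seq_lens=[40],
#       sm_scale=head_dim ** -0.5)
#   # out: [1, num_heads, value head dim]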


@dataclass
class MLACommonPrefillMetadata:
    """ Prefill Specific Metadata """

    @dataclass
    class ChunkedContextMetadata:
        # New for MLA (compared to FlashAttention)
        # For handling chunked prefill
        cu_seq_lens: torch.Tensor
        starts: torch.Tensor
        seq_tot: list[int]
        max_seq_lens: list[int]
        workspace: torch.Tensor

    block_tables: torch.Tensor  # block_table => block_tables for v0 compatibility
    query_start_loc: torch.Tensor
    # max_query_len: int
    seq_lens: list[int]
    chunked_context: Optional[ChunkedContextMetadata] = None


@dataclass
class MLACommonDecodeMetadata:
    block_tables: torch.Tensor  # block_table => block_tables for v0 compatibility
    seq_lens: torch.Tensor


class VACCMLAMetadata(MLACommonMetadata):
    """Metadata for VACCMLABackend.

    NOTE: Any python object stored here is not updated when it is
    cuda-graph replayed. If you have values that need to be changed
    dynamically, it should be stored in tensor. The tensor has to be
    updated from `CUDAGraphRunner.forward` API.
    """
    # (batch_size,). The sequence length per sequence. Sequence length means
    # the computed tokens + new tokens (None if it is a decoding).
    seq_lens: Optional[List[int]]
    # seq_lens stored as a tensor.
    seq_lens_tensor: Optional[torch.Tensor]

    # NOTE(sang): Definition of context_len, query_len, and seq_len.
    # |---------- N-1 iteration --------|
    # |---------------- N iteration ---------------------|
    # |- tokenA -|......................|-- newTokens ---|
    # |---------- context_len ----------|
    # |-------------------- seq_len ---------------------|
    #                                   |-- query_len ---|

    # Maximum sequence length among prefill batch. 0 if there are decoding
    # requests only.
    max_prefill_seq_len: int
    # Maximum sequence length among decode batch. 0 if there are prefill
    # requests only.
    max_decode_seq_len: int
    # (batch_size,) A tensor of context lengths (tokens that are computed
    # so far).
    context_lens_tensor: Optional[torch.Tensor]

    # (batch_size, max_blocks_per_seq).
    # Block addresses per sequence. (Seq id -> list of physical block)
    # E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks
    # in the kv cache. Each block can contain up to block_size tokens.
    # 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph
    # captured.
    block_tables: Optional[torch.Tensor]

    # Whether or not cuda graph is enabled.
    # Cuda-graph is currently enabled for decoding only.
    # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
    use_cuda_graph: bool

    # Maximum query length in the batch.
    max_query_len: Optional[int] = None
    input_positions: Optional[torch.Tensor] = None
    # Max number of query tokens among requests in the batch.
    max_decode_query_len: Optional[int] = None

    # (batch_size + 1,). The cumulative subquery lengths of the sequences in
    # the batch, used to index into subquery. E.g., if the subquery length
    # is [4, 6], it is [0, 4, 10].
    query_start_loc: Optional[torch.Tensor] = None
    # (batch_size + 1,). The cumulative sequence lengths of the sequences in
    # the batch, used to index into sequence. E.g., if the sequence length is
    # [4, 6], it is [0, 4, 10].
    seq_start_loc: Optional[torch.Tensor] = None

    _cached_prefill_metadata: Optional["VACCMLAMetadata"] = None
    _cached_decode_metadata: Optional["VACCMLAMetadata"] = None

    num_prefill_tokens: int

    num_kv_splits: int = 4  # TODO(lucas) add heuristic
    attn_logits: Optional[torch.Tensor] = None
    req_idx: Optional[torch.Tensor] = None

    # The dimension of the attention heads
    head_dim: Optional[int] = None

    def __post_init__(self):
        supported_head_sizes = VACCMLABackend.get_supported_head_sizes()
        if self.head_dim is not None and self.head_dim \
                not in supported_head_sizes:
            raise ValueError(
                f"Only {supported_head_sizes} are supported for head_dim, "
                f"received {self.head_dim}.")

    @property
    def prefill_metadata(self) -> Optional["VACCMLAMetadata"]:
        if self.num_prefills == 0:
            return None

        if self._cached_prefill_metadata is not None:
            return self._cached_prefill_metadata

        assert self.seq_lens is not None
        assert self.seq_lens_tensor is not None

        # Compute some attn_metadata fields which default to None
        query_start_loc = (None if self.query_start_loc is None else
                           self.query_start_loc[:self.num_prefills + 1])
        slot_mapping = (None if self.slot_mapping is None else
                        self.slot_mapping[:self.num_prefill_tokens])
        seq_lens = (None if self.seq_lens is None else
                    self.seq_lens[:self.num_prefills])
        seq_lens_tensor = (None if self.seq_lens_tensor is None else
                           self.seq_lens_tensor[:self.num_prefills])
        seq_start_loc = (None if self.seq_start_loc is None else
                         self.seq_start_loc[:self.num_prefills + 1])
        context_lens_tensor = (None if self.context_lens_tensor is None else
                               self.context_lens_tensor[:self.num_prefills])
        block_tables = (None if self.block_tables is None else
                        self.block_tables[:self.num_prefills])
        input_positions = (None if self.input_positions is None else
                           self.input_positions[:self.num_prefill_tokens])

        self._cached_prefill_metadata = VACCMLAMetadata(
            num_prefills=self.num_prefills,
            num_prefill_tokens=self.num_prefill_tokens,
            num_decode_tokens=0,
            slot_mapping=slot_mapping,
            multi_modal_placeholder_index_maps=self.
            multi_modal_placeholder_index_maps,
            enable_kv_scales_calculation=self.enable_kv_scales_calculation,
            input_positions=input_positions,
            seq_lens=seq_lens,
            seq_lens_tensor=seq_lens_tensor,
            max_prefill_seq_len=None,
            max_decode_seq_len=0,
            query_start_loc=query_start_loc,
            seq_start_loc=seq_start_loc,
            context_lens_tensor=context_lens_tensor,
            block_tables=block_tables,
            use_cuda_graph=False,
            head_dim=self.head_dim)
        return self._cached_prefill_metadata

    @property
    def decode_metadata(self) -> Optional["VACCMLAMetadata"]:
        if self.num_decode_tokens == 0:
            return None

        if self._cached_decode_metadata is not None:
            return self._cached_decode_metadata
        assert self.seq_lens_tensor is not None

        # Compute some attn_metadata fields which default to None
        slot_mapping = (None if self.slot_mapping is None else
                        self.slot_mapping[self.num_prefill_tokens:])
        seq_lens_tensor = (None if self.seq_lens_tensor is None else
                           self.seq_lens_tensor[self.num_prefills:])
        block_tables = (None if self.block_tables is None else
                        self.block_tables[self.num_prefills:])
        input_positions = (None if self.input_positions is None else
                           self.input_positions[self.num_prefill_tokens:])

        self._cached_decode_metadata = VACCMLAMetadata(
            num_prefills=0,
            num_prefill_tokens=0,
            num_decode_tokens=self.num_decode_tokens,
            slot_mapping=slot_mapping,
            multi_modal_placeholder_index_maps=None,
            enable_kv_scales_calculation=True,
            seq_lens=self.seq_lens,
            seq_lens_tensor=seq_lens_tensor,
            max_decode_query_len=self.max_decode_query_len,
            max_query_len=self.max_query_len,
            max_prefill_seq_len=0,
            max_decode_seq_len=self.max_decode_seq_len,
            # Batch may be composed of prefill|decodes, adjust query start
            # indices to refer to the start of decodes. E.g.
            # in tokens:[3 prefills|6 decodes], query_start_loc=[3,9] => [0,6].
            query_start_loc=(self.query_start_loc[self.num_prefills:] -
                             self.query_start_loc[self.num_prefills])
            if self.query_start_loc is not None else None,
            seq_start_loc=self.seq_start_loc[self.num_prefills:]
            if self.seq_start_loc is not None else None,
            context_lens_tensor=None,
            block_tables=block_tables,
            use_cuda_graph=self.use_cuda_graph,
            input_positions=input_positions,
            head_dim=self.head_dim)
        return self._cached_decode_metadata

    def advance_step(self,
                     model_input: "ModelInputForVACCWithSamplingMetadata",
                     sampled_token_ids: Optional[torch.Tensor],
                     block_size: int,
                     num_seqs: int,
                     num_queries: int,
                     turn_prefills_into_decodes: bool = False):
        """
        Update metadata in-place to advance one decode step.
        """
        # When using cudagraph, num_seqs is padded to the next captured
        # batch size, but num_queries tracks the actual number of requests in
        # the batch. For --enforce-eager mode, num_seqs == num_queries.
        if num_seqs != num_queries:
            assert num_seqs > num_queries
            assert self.use_cuda_graph

        if turn_prefills_into_decodes:
            # When Multi-Step is enabled with Chunked-Prefill, prefills and
            # decodes are scheduled together. In the first step, all the
            # prefills turn into decodes. This update reflects that
            # conversion.
            assert self.num_decode_tokens + self.num_prefills == num_seqs
            self.num_decode_tokens += self.num_prefills
            self.num_prefills = 0
            # self.num_prefill_tokens = 0
            # self.max_prefill_seq_len = 0
            self.max_query_len = 1

            self.slot_mapping = self.slot_mapping[:num_seqs]
        else:
            assert self.seq_lens is not None
            assert self.max_decode_seq_len == max(self.seq_lens)

        assert self.num_prefills == 0
        assert self.num_prefill_tokens == 0
        assert self.num_decode_tokens == num_seqs
        assert self.slot_mapping.shape == (num_seqs, )

        assert self.seq_lens is not None
        assert len(self.seq_lens) == num_seqs
        assert self.seq_lens_tensor is not None
        assert self.seq_lens_tensor.shape == (num_seqs, )
        # assert self.max_query_len == 1
        # assert self.max_prefill_seq_len == 0

        assert self.query_start_loc is not None
        assert self.query_start_loc.shape == (num_queries + 1, )
        assert self.seq_start_loc is not None
        assert self.seq_start_loc.shape == (num_seqs + 1, )

        assert self.context_lens_tensor is not None
        assert self.context_lens_tensor.shape == (num_queries, )

        assert self.block_tables is not None
        assert self.block_tables.shape[0] == num_seqs

        # Update query lengths. Note that we update only queries and not seqs,
        # since tensors may be padded due to captured cuda graph batch size
        for i in range(num_queries):
            self.seq_lens[i] += 1
        # self.max_decode_seq_len = None

        ops.advance_step_flashattn(num_seqs=num_seqs,
                                   num_queries=num_queries,
                                   block_size=block_size,
                                   input_tokens=model_input.input_tokens,
                                   sampled_token_ids=sampled_token_ids,
                                   input_positions=model_input.input_positions,
                                   seq_lens=self.seq_lens_tensor,
                                   slot_mapping=self.slot_mapping,
                                   block_tables=self.block_tables)


class VACCMLAImpl(MLACommonImpl[VACCMLAMetadata]):

    def __init__(
            self,
            num_heads: int,
            head_size: int,
            scale: float,
            num_kv_heads: int,
            alibi_slopes: Optional[List[float]],
            sliding_window: Optional[int],
            kv_cache_dtype: str,
            # blocksparse_params: Optional[Dict[str, Any]],
            logits_soft_cap: Optional[float],
            attn_type: str,
            kv_sharing_target_layer_name: Optional[str],
            # MLA Specific Arguments
            **kwargs) -> None:
        # super().__init__(num_heads, head_size, scale, num_kv_heads,
        #                  alibi_slopes, sliding_window, kv_cache_dtype,
        #                  logits_soft_cap, attn_type,
        #                  kv_sharing_target_layer_name, **kwargs)
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_kv_heads
        self.kv_cache_dtype = kv_cache_dtype

        # print('kwargs', kwargs)
        # Manually pick up the MLA-specific arguments instead of calling the
        # older-version base-class __init__.
        self.q_lora_rank = kwargs.get('q_lora_rank')
        self.kv_lora_rank = kwargs.get('kv_lora_rank')
        self.qk_nope_head_dim = kwargs.get('qk_nope_head_dim')
        self.qk_head_dim = kwargs.get('qk_head_dim')
        self.v_head_dim = kwargs.get('v_head_dim')
        self.rotary_emb = kwargs.get('rotary_emb')
        self.q_proj = kwargs.get('q_proj')
        self.kv_b_proj = kwargs.get('kv_b_proj')
        self.o_proj = kwargs.get('o_proj')

        unsupported_features = [
            alibi_slopes, sliding_window, logits_soft_cap
        ]
        if any(unsupported_features):
            raise NotImplementedError(
                "VACCMLAImpl does not support one of the following: "
                "alibi_slopes, sliding_window, blocksparse_params, "
                "logits_soft_cap")

        if attn_type != AttentionType.DECODER:
            raise NotImplementedError("Encoder self-attention and "
                                      "encoder/decoder cross-attention "
                                      "are not implemented for "
                                      "VACCMLAImpl")

    def extract_weights(self):
        weight_names = [
            "W_Q", "W_Q_scales", "W_QR", "W_QR_scales",
            "W_Q_QR", "W_Q_QR_scales", "W_UK", "W_UK_scales",
            "W_Q_UK_scales", "W_UV", "W_UV_scales",
            "W_UV_O", "W_UV_O_scales",
        ]
        return {name: getattr(self, name) for name in weight_names
                if hasattr(self, name)}

    def _forward_prefill(
        self,
        q: torch.Tensor,
        kv_c_normed: torch.Tensor,
        k_pe: torch.Tensor,
        kv_c_and_k_pe_cache: torch.Tensor,
        attn_metadata: VACCMLAMetadata,
    ) -> torch.Tensor:

        assert isinstance(attn_metadata, VACCMLAV1Metadata)
        kv_nope = self.kv_b_proj(kv_c_normed)[0]\
            .view(-1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
        k_nope, v = kv_nope\
            .split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)

        k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1)
        v = v.contiguous()

        # For MLA the v head dim is smaller than qk head dim so we pad out
        # v with 0s to match the qk head dim
        # v_padded = torch.nn.functional.pad(v, [0, q.shape[-1] - v.shape[-1]],
        #                                    value=0)
        # attn_output = torch.vacc.scaled_dot_product_attention(
        #     query=q,
        #     key=k,
        #     value=v_padded,
        #     attn_mask=None,
        #     dropout_p=0,
        #     is_causal=True,
        #     is_train=False,
        #     recompute=False,
        #     flash_attention=True,
        #     sm_scale=self.scale
        # )
        # attn_output = attn_output\
        #     .view(-1, self.num_heads, q.shape[-1])[..., :v.shape[-1]]\
        #     .reshape(-1, self.num_heads * v.shape[-1])
        seq_lens = attn_metadata.prefill_metadata.seq_lens
        if len(seq_lens) == 1:
            # VACC supports different head dims for v and qk.
            attn_output = torch.vacc.scaled_dot_product_attention(
                query=q,
                key=k,
                value=v,
                attn_mask=None,
                dropout_p=0,
                is_causal=True,
                is_train=False,
                recompute=False,
                flash_attention=False,
                sm_scale=self.scale
            )
            attn_out = attn_output.view(-1, self.num_heads * v.shape[-1])
        else:
            attn_outs = []
            start = 0
            for seq in seq_lens:
                end = start + seq
                attn_out = torch.vacc.scaled_dot_product_attention(
                    query=q[start:end, :],
                    key=k[start:end, :],
                    value=v[start:end, :],
                    attn_mask=None,
                    dropout_p=0,
                    is_causal=True,
                    is_train=False,
                    recompute=False,
                    flash_attention=False,
                    sm_scale=self.scale
                )
                start = end
                attn_outs.append(attn_out)
            attn_out = torch.cat(attn_outs, dim=0).view(-1, self.num_heads * v.shape[-1])

        return self.o_proj(attn_out)[0]
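
    # Shape sketch for the prefill path above (hypothetical DeepSeek-style dims:
    # qk_nope_head_dim=128, qk_rope_head_dim=64, v_head_dim=128, qk_head_dim=192):
    #   q        : [num_prefill_tokens, num_heads, 192]
    #   k        : [num_prefill_tokens, num_heads, 192]  (nope concat broadcast rope)
    #   v        : [num_prefill_tokens, num_heads, 128]
    #   attn_out : [num_prefill_tokens, num_heads * 128] -> o_proj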

    def _forward_decode(
        self,
        q_nope: torch.Tensor,
        q_pe: torch.Tensor,
        kv_c_and_k_pe_cache: torch.Tensor,
        attn_metadata: VACCMLAMetadata,
    ) -> torch.Tensor:
        assert kv_c_and_k_pe_cache.numel() > 0
        if self.kv_cache_dtype.startswith("fp8"):
            raise NotImplementedError("FP8 Triton MLA not yet supported")

        decode_meta = attn_metadata.decode_metadata
        assert decode_meta is not None
        B = q_nope.shape[0]

        q = torch.cat([q_nope, q_pe], dim=-1)
        o = torch.zeros(B,
                        self.num_heads,
                        self.kv_lora_rank,
                        dtype=q.dtype,
                        device=q.device)

        # Add a head dim of 1
        kv_c_and_k_pe_cache = kv_c_and_k_pe_cache.unsqueeze(2)
        # print(f"kv_c_and_k_pe_cache: {kv_c_and_k_pe_cache.shape} ")
        kv_c_cache = kv_c_and_k_pe_cache[..., :self.kv_lora_rank]

        # Run MQA using paged_attention
        # o = torch.vacc.paged_attention(
        #     query=q,
        #     key_cache=kv_c_and_k_pe_cache,
        #     value_cache=kv_c_cache,
        #     block_table=decode_meta.block_tables,
        #     seq_len=decode_meta.seq_lens_tensor,
        #     out=o,
        #     sm_scale=self.scale
        # )

        # Run MQA using SDPA
        # t0 = time.time()
        o = vacc_paged_attention_naive(
            q,
            kv_c_and_k_pe_cache,
            kv_c_cache,
            block_table=decode_meta.block_tables,
            # seq_lens=decode_meta.seq_lens_tensor,
            seq_lens=decode_meta.seq_lens,
            out=o,
            sm_scale=self.scale)

        return self._v_up_proj_and_o_proj(o)
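
    # The decode path runs MQA directly in the compressed KV space: each query is
    # [q_nope | q_pe] per head and the cache holds one latent vector per token.
    # With hypothetical DeepSeek-style dims (kv_lora_rank=512, rope dim 64):
    #   q : [B, num_heads, 576]
    #   k : latent cache viewed as [num_blocks, block_size, 1, 576]
    #   v : the first 512 dims of the same latent cache
    #   o : [B, num_heads, 512], then up-projected and output-projected.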


# patch from MLACommonBackend
class VACCMLABackend(AttentionBackend):

    accept_output_buffer: bool = False

    @staticmethod
    def get_name() -> str:
        return "TRITON_MLA_VLLM_V1"

    @staticmethod
    def get_metadata_cls() -> type["AttentionMetadata"]:
        return VACCMLAMetadata

    @staticmethod
    def get_impl_cls() -> Type["VACCMLAImpl"]:
        return VACCMLAImpl

    @staticmethod
    def get_builder_cls() -> type["MLAVaccMetadataBuilder"]:
        return MLAVaccMetadataBuilder

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,  # assumed to be 1 for MLA
        head_size: int,
        cache_dtype_str: str = "auto",
    ) -> tuple[int, ...]:
        return (num_blocks, block_size, head_size)
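
    # For example, with 16-token blocks and the 576-dim MLA latent
    # (512 kv_lora_rank + 64 rope dims, DeepSeek-style), a cache of 1024 blocks
    # would be allocated as (1024, 16, 576): one latent vector per token, with
    # no separate K/V or per-head dimension.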

    @classmethod
    def get_supported_head_sizes(cls) -> list[int]:
        return [576]

    @classmethod
    def validate_head_size(cls, head_size: int) -> None:
        supported_head_sizes = cls.get_supported_head_sizes()
        if head_size not in supported_head_sizes:
            attn_type = cls.__name__.removesuffix("Backend")
            raise ValueError(
                f"Head size {head_size} is not supported by {attn_type}. "
                f"Supported head sizes are: {supported_head_sizes}. "
                "Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use "
                "FlexAttention backend which supports all head sizes.")


D = TypeVar("D", bound=MLACommonDecodeMetadata)


# patch from MLACommonMetadata
@dataclass
class VACCMLAV1Metadata(Generic[D]):
    """Metadata for the VACC MLA backend (patched from MLACommonMetadata).

    NOTE: Please read the comment at the top of the file before trying to
    understand this class
    """
    # NOTE(sang): Definition of context_len, query_len, and seq_len.
    # |---------- N-1 iteration --------|
    # |---------------- N iteration ---------------------|
    # |- tokenA -|......................|-- newTokens ---|
    # |---------- context_len ----------|
    # |-------------------- seq_len ---------------------|
    #                                   |-- query_len ---|

    num_actual_tokens: int  # Number of tokens excluding padding.
    query_start_loc: torch.Tensor
    slot_mapping: torch.Tensor

    # New for MLA (compared to FlashAttention)
    # For handling prefill decode split
    num_decodes: int
    num_decode_tokens: int
    num_prefills: int

    num_prefill_tokens: int

    # The dimension of the attention heads
    head_dim: Optional[int] = None

    decode_metadata: Optional[D] = None
    prefill_metadata: Optional[MLACommonPrefillMetadata] = None

    def __post_init__(self):
        if self.head_dim is not None:
            MLACommonBackend.validate_head_size(self.head_dim)


class MLAVaccMetadataBuilder(AttentionMetadataBuilder[M]):
    """
    NOTE: Please read the comment at the top of the file before trying to
    understand this class
    """
    reorder_batch_threshold: ClassVar[int] = 2
    # TODO: make the prefill/decode split threshold configurable

    def __init__(self,
                 # runner: "GPUModelRunner",
                 kv_cache_spec: AttentionSpec,
                 layer_names: list[str],
                 vllm_config: VllmConfig,
                 device: torch.device,
                 metadata_cls: Optional[type[M]] = None):
        self._global_hyperparameters = infer_global_hyperparameters(
            get_per_layer_parameters(vllm_config, layer_names,
                                     MLACommonImpl))
        self.metadata_cls = metadata_cls \
            if metadata_cls is not None else VACCMLAV1Metadata
        self.kv_cache_spec = kv_cache_spec
        scheduler_config = vllm_config.scheduler_config
        self.model_config = vllm_config.model_config
        parallel_config = vllm_config.parallel_config
        cache_config = vllm_config.cache_config
        self.compilation_config = vllm_config.compilation_config
        self.device = device
        # self.runner = runner
        # scheduler_config = runner.scheduler_config
        # model_config = runner.model_config
        # cache_config = runner.cache_config
        self.chunked_prefill_enabled = False
        self.num_heads = self.model_config.get_num_attention_heads(
            parallel_config)
        self.mla_dims = get_mla_dims(self.model_config)
        self.aot_schedule = current_platform.is_cuda()

        # Don't try to access the runner on AMD
        if self.aot_schedule:
            self.page_size = self.kv_cache_spec.block_size

        # if self.chunked_prefill_enabled:
        #     self.chunked_prefill_workspace_size = min(
        #         # Make sure there is enough for 8 full length requests or at
        #         # least 4 pages of cache per request
        #         max(
        #             8 * model_config.max_model_len, 4 *
        #             scheduler_config.max_num_seqs * cache_config.block_size),
        #         # For long-context models try not to over-allocate limiting
        #         # kv-cache space, limiting it to 64k tokens,
        #         # which would result in the workspace being:
        #         # 2*(576)*(64*1024) = 144mb
        #         # (assuming 576 MLA head dim, and fp16)
        #         # which would result in up-projected context being
        #         # 2*(192*128)*(64*1024) = 3gb
        #         # (assuming 192 QK head dim, 128 heads, and fp16)
        #         128 * 1024)
        #     assert self.chunked_prefill_workspace_size >= \
        #         scheduler_config.max_num_seqs * cache_config.block_size
        #     self.chunked_prefill_workspace = torch.empty(
        #         (self.chunked_prefill_workspace_size,
        #          model_config.get_head_size()),
        #         dtype=model_config.dtype,
        #         device=runner.device,
        #     )
        # self.block_table = block_table

    def reorder_batch(self, input_batch: "InputBatch",
                      scheduler_output: "SchedulerOutput") -> bool:
        # We now want to reorder the batch so that the "decode" requests are at
        # the front and the "prefill" requests are at the back, using the least
        # amount of swaps possible. (NOTE for now we loosely use "decode" to
        # mean requests where attention is likely memory-bound and "prefill"
        # to mean requests where attention is likely compute-bound,
        # TODO(lucas): figure out a better naming here)
        decodes = []
        prefills = []
        num_decode_tokens = 0
        num_prefill_tokens = 0

        for i, req_id in enumerate(input_batch.req_ids):
            num_tokens = scheduler_output.num_scheduled_tokens[req_id]
            # for now treat 1 scheduled token as "decode" even if it's not,
            # we should update this to something like < 8 in the future but
            # currently the TritonMLA._forward_decode only supports
            # num_tokens = 1
            if scheduler_output.scheduled_cached_reqs.num_computed_tokens != []:
                decodes.append(i)
                num_decode_tokens += num_tokens
            else:
                prefills.append(i)
                num_prefill_tokens += num_tokens

        # We hope that this is fairly minimal since decodes
        # should be around for a number of iterations so hopefully they are
        # relatively stationary (and new requests are generally appended to
        # the persistent batch so already should be at the back).
        # To achieve this we loop over the decodes in descending order and
        # the prefills in ascending order. We swap decodes from the "back",
        # i.e. past where the last decode should be in the reordered batch,
        # with prefills from the front of the batch.
        # `decodes` and `prefills` are already in ascending order just based on
        # the above loop
        num_decodes = len(decodes)
        num_prefills = len(prefills)
        modified_batch = False

        for i in range(1, min(num_decodes, num_prefills) + 1):
            # If the decode is at the "back" of the batch, i, we can swap it
            # with the prefill closest to the front of the batch
            decode_idx = decodes[num_decodes - i]
            if decode_idx < num_decodes:
                break

            input_batch.swap_states(prefills[i - 1], decode_idx)
            modified_batch = True

        # Save for next `build` call
        # TODO(lucas): this is a bit of a hack, we should probably have a
        # better way of doing this
        self._num_decodes = num_decodes
        self._num_prefills = num_prefills
        self._num_decode_tokens = num_decode_tokens
        self._num_prefill_tokens = num_prefill_tokens

        return modified_batch
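
    # A worked example of the swap loop above (hypothetical batch): with request
    # order [P, D, P, D] we get decodes=[1, 3], prefills=[0, 2], num_decodes=2.
    # Iteration i=1 swaps prefill index 0 with decode index 3 -> [D, D, P, P];
    # iteration i=2 sees decode index 1 < num_decodes and stops, leaving all
    # decodes at the front of the batch.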

    def _build_decode(self, block_table_tensor: torch.Tensor,
                      seq_lens: torch.Tensor):
        return MLACommonDecodeMetadata(
            block_tables=block_table_tensor,
            seq_lens=seq_lens,
        )

    # def build_for_cudagraph_capture(
    #         self, common_attn_metadata: CommonAttentionMetadata) -> M:
    #     """
    #     This method builds the metadata for full cudagraph capture.
    #     Currently, only decode is supported for full cudagraphs with MLA.
    #     """
    #     m = common_attn_metadata
    #     assert m.num_reqs == m.num_actual_tokens, \
    #         "MLA only supports decode-only full CUDAGraph capture. " \
    #         "Make sure all cudagraph capture sizes <= max_num_seq."

    #     # m.max_query_len = 1  # decode-only

    #     # Update state usually set in reorder_batch.
    #     self._num_decodes = m.num_reqs
    #     self._num_decode_tokens = m.num_actual_tokens
    #     self._num_prefills = 0
    #     self._num_prefill_tokens = 0
    #     return self.build(0, m)

    def append_seqlen(self, seq_len: list[int], all_len: int):
        # print('append_seqlen seq_len', seq_len)
        # print('append_seqlen all_len', all_len)
        if all_len > len(seq_len) and all_len % len(seq_len) == 0:
            new_seq_len = []
            mtp_num = all_len // len(seq_len)
            for start_len in seq_len:
                for i in range(1, 1 + mtp_num):
                    new_seq_len.append(start_len - mtp_num + i)
            return new_seq_len
        return seq_len
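
    # Worked example of the MTP expansion above: append_seqlen([10, 12], all_len=4)
    # gives mtp_num=2 and expands each request's length into per-token lengths
    # [9, 10, 11, 12]; when all_len == len(seq_len) the list is returned unchanged.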

    def build(self, common_prefix_len: int,
              common_attn_metadata: CommonAttentionMetadata,
              fast_build: bool = False) -> M:
        num_reqs = common_attn_metadata.num_reqs
        num_actual_tokens = common_attn_metadata.num_actual_tokens
        # max_query_len = common_attn_metadata.max_query_len

        # assert self._num_decodes + self._num_prefills == num_reqs

        # Note(simon): be careful about the CPU <> GPU memory movement in this
        # function. We should avoid GPU -> CPU sync as much as possible because
        # it blocks on all previous kernels.
        # device = self.runner.device
        # block_table = self.block_table
        # block_table_tensor = block_table.get_device_tensor()  # [:num_reqs]

        block_table_tensor = common_attn_metadata.block_table_tensor
        slot_mapping = common_attn_metadata.slot_mapping
        # Example decode-time CommonAttentionMetadata:
        #   query_start_loc=[0, 2], seq_lens=[40], num_computed_tokens_cpu=[38],
        #   num_reqs=1, num_actual_tokens=2, max_query_len=2, max_seq_len=40,
        #   block_table_tensor=[[1536, 1537, 1538, ...]]

        # block_table.slot_mapping[:num_actual_tokens].copy_(
        #     block_table.slot_mapping_cpu[:num_actual_tokens],
        #     non_blocking=True)
        # # block_table.slot_mapping[num_actual_tokens:].fill_(-1)
        # slot_mapping = block_table.slot_mapping[:num_actual_tokens]
        # num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \
        #     split_decodes_and_prefills(common_attn_metadata,
        #                                decode_threshold=self.reorder_batch_threshold)
        num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = 0, 0, 0, 0
        token_nums = (common_attn_metadata.query_start_loc_cpu[1:] -
                      common_attn_metadata.query_start_loc_cpu[:-1])
        if token_nums.max().item() > self.reorder_batch_threshold:
            num_prefills = num_reqs
            num_prefill_tokens = num_actual_tokens
        else:
            num_decodes = num_reqs
            num_decode_tokens = num_actual_tokens

        query_start_loc = common_attn_metadata.query_start_loc
        seq_lens = common_attn_metadata.seq_lens

        prefill_metadata = None
        if num_prefills > 0:
            reqs_start = num_decodes  # prefill_start

            # context_lens_cpu = self.runner.input_batch.\
            #     num_computed_tokens_cpu_tensor[reqs_start:num_reqs]
            # max_context_len_cpu = context_lens_cpu.max().item()
            # num_prefills_with_context_cpu = (context_lens_cpu > 0).sum().item()
            prefill_query_start_loc = query_start_loc[
                reqs_start:] - query_start_loc[reqs_start]

            chunked_context_metadata = None
            # if self.chunked_prefill_enabled and self._num_prefills > 0 \
            #         and max_context_len_cpu > 0:
            # chunked prefill is not supported

            prefill_metadata = MLACommonPrefillMetadata(
                block_tables=block_table_tensor[reqs_start:reqs_start + num_prefills, ...],
                query_start_loc=prefill_query_start_loc,
                # max_query_len=None,
                chunked_context=chunked_context_metadata,
                seq_lens=seq_lens[:num_prefills],
            )

        if not isinstance(seq_lens, list):
            # TODO: set this as a list at init time (vllm/v1/spec_decode/eagle.py)
            seq_lens = seq_lens.tolist()
        decode_metadata = None
        if num_decodes > 0:
            block_num_per_group = env_blk_grp_size // 16
            block_table_tensor_new = block_table_tensor[:num_decodes, ::block_num_per_group].contiguous()

            seq_lens_new = self.append_seqlen(seq_lens[:slot_mapping.shape[-1]],
                                              slot_mapping.shape[-1])

            if slot_mapping.shape[-1] > num_decodes:
                # i.e. query_start_loc[1:] - query_start_loc[:-1]
                mtp_numbers = [query_start_loc[i + 1] - query_start_loc[i]
                               for i in range(len(query_start_loc) - 1)]

                block_table_tensor_list = []
                for bi, mtp_number in enumerate(mtp_numbers):
                    for _ in range(mtp_number):
                        block_table_tensor_list.append(block_table_tensor_new[bi:bi + 1])
                block_table_tensor_new = torch.concatenate(block_table_tensor_list, 0)

            decode_metadata = self._build_decode(
                block_table_tensor=block_table_tensor_new,
                seq_lens=seq_lens_new,
            )
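
            # As a concrete (hypothetical) decode example: with BLOCK_GROUP_SIZE=64
            # and 16-token blocks, block_num_per_group=4, so only every 4th entry
            # of each block-table row is kept (one id per block group). If two
            # tokens are scheduled per request (slot_mapping has 2 * num_decodes
            # entries), each surviving row is repeated once per scheduled token
            # and seq_lens is expanded by append_seqlen accordingly.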

        return self.metadata_cls(
            num_actual_tokens=num_actual_tokens,
            query_start_loc=query_start_loc,
            slot_mapping=slot_mapping,
            head_dim=self.model_config.get_head_size(),
            # prefill_seq_lens=seq_lens[:self._num_prefills].tolist(),  # device to host, TODO: optimize
            # MLACommonMetadata chunked-prefill specific
            num_decodes=num_decodes,
            num_decode_tokens=num_decode_tokens,
            num_prefills=num_prefills,
            num_prefill_tokens=num_prefill_tokens,
            prefill_metadata=prefill_metadata,
            decode_metadata=decode_metadata,
        )