from collections import defaultdict
from contextlib import contextmanager
from dataclasses import dataclass
from itertools import accumulate
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type
from vllm.multimodal import MultiModalPlaceholderMap
try:
from flashinfer import BatchDecodeMlaWithPagedKVCacheWrapper
FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024
except ImportError:
BatchDecodeMlaWithPagedKVCacheWrapper = None
FLASHINFER_WORKSPACE_BUFFER_SIZE = 0
import torch
from vllm import _custom_ops as ops
from vllm.attention.backends.abstract import (AttentionBackend,
AttentionMetadata,
AttentionMetadataBuilder,
AttentionState, AttentionType)
from vllm.attention.backends.mla.common import MLACommonImpl, MLACommonMetadata
from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping,
compute_slot_mapping_start_idx,
is_block_tables_empty)
#from vllm.attention.ops.paged_attn import PagedAttention
from vllm_vacc.vllm.attention.ops.vacc_paged_attn import VaccPagedAttention as PagedAttention
# from vllm.attention.ops.triton_decode_attention import decode_attention_fwd
from vllm.utils import async_tensor_h2d, make_tensor_with_pad
# import time, os
if TYPE_CHECKING:
from vllm_vacc.vllm.worker.vacc_model_runner import (ModelInputForVACCBuilder,
ModelInputForVACCWithSamplingMetadata)
class VACCMLABackend(AttentionBackend):
@staticmethod
def get_name() -> str:
return "TORCH_VACC"
@staticmethod
def get_impl_cls() -> Type["VACCMLAImpl"]:
return VACCMLAImpl
@staticmethod
def get_metadata_cls() -> Type["AttentionMetadata"]:
return VACCMLAMetadata
@staticmethod
def get_builder_cls() -> Type["VACCMLAMetadataBuilder"]:
return VACCMLAMetadataBuilder
@staticmethod
def get_state_cls() -> Type["VACCMLAState"]:
return VACCMLAState
@staticmethod
def get_kv_cache_shape(
num_blocks: int,
block_size: int,
num_kv_heads: int, # assumed to be 1 for MLA
head_size: int,
) -> Tuple[int, ...]:
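        # MLA keeps a single latent KV entry per token (one shared KV head),
        # so the head dim is folded into head_size; e.g. 576 = kv_lora_rank
        # (512) + qk_rope_head_dim (64) for DeepSeek-style MLA.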
return (num_blocks, block_size, head_size)
@staticmethod
def swap_blocks(
src_kv_cache: torch.Tensor,
dst_kv_cache: torch.Tensor,
src_to_dst: torch.Tensor,
) -> None:
PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)
@staticmethod
def copy_blocks(
kv_caches: List[torch.Tensor],
src_to_dists: torch.Tensor,
) -> None:
PagedAttention.copy_blocks(kv_caches, src_to_dists)
@staticmethod
def get_supported_head_sizes() -> List[int]:
return [576]
class VACCMLAState(AttentionState):
def __init__(self, runner):
self.runner = runner
self._is_graph_capturing = False
@contextmanager
def graph_capture(self, max_batch_size: int):
self._is_graph_capturing = True
self._graph_slot_mapping = torch.full((max_batch_size, ),
PAD_SLOT_ID,
dtype=torch.long,
device=self.runner.device)
self._graph_seq_lens = torch.ones(max_batch_size,
dtype=torch.int32,
device=self.runner.device)
self._graph_block_tables = torch.from_numpy(
self.runner.graph_block_tables).to(device=self.runner.device)
self._positions = torch.zeros((max_batch_size, ),
dtype=torch.long,
device=self.runner.device)
yield
self._is_graph_capturing = False
del self._graph_slot_mapping
del self._graph_seq_lens
del self._graph_block_tables
del self._positions
def graph_clone(self, batch_size: int):
assert self._is_graph_capturing
return self.__class__(self.runner)
def graph_capture_get_metadata_for_batch(
self, batch_size: int, is_encoder_decoder_model: bool = False):
assert self._is_graph_capturing
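        # Build decode-only dummy metadata over the persistent capture
        # buffers; the real values are copied into these buffers by
        # prepare_graph_input_buffers before each graph replay.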
attn_metadata = self.runner.attn_backend.make_metadata(
num_prefills=0,
num_prefill_tokens=0,
num_decode_tokens=batch_size,
slot_mapping=self._graph_slot_mapping[:batch_size],
multi_modal_placeholder_index_maps=None,
enable_kv_scales_calculation=True,
seq_lens=None,
seq_lens_tensor=self._graph_seq_lens[:batch_size],
# max_query_len=1,
# max_decode_query_len=1,
max_prefill_seq_len=0,
max_decode_seq_len=self.runner.max_seq_len_to_capture,
query_start_loc=None,
seq_start_loc=None,
context_lens_tensor=None,
block_tables=self._graph_block_tables[:batch_size],
use_cuda_graph=True,
input_positions=self._positions[:batch_size],
head_dim=self.runner.model_config.get_head_size())
if is_encoder_decoder_model:
raise NotImplementedError(
"VACCMLAState does not support encoder/decoder yet")
return attn_metadata
def get_graph_input_buffers(self,
attn_metadata,
is_encoder_decoder_model: bool = False):
input_buffers = {
"slot_mapping": attn_metadata.slot_mapping,
"seq_lens_tensor": attn_metadata.decode_metadata.seq_lens_tensor,
"block_tables": attn_metadata.decode_metadata.block_tables,
"input_positions": attn_metadata.decode_metadata.input_positions,
}
if is_encoder_decoder_model:
raise NotImplementedError(
"VACCMLAState does not support encoder/decoder yet")
return input_buffers
def prepare_graph_input_buffers(self,
input_buffers,
attn_metadata,
is_encoder_decoder_model: bool = False):
input_positions = attn_metadata.input_positions
num_positions = input_positions.shape[0]
input_buffers["seq_lens_tensor"].copy_(
attn_metadata.decode_metadata.seq_lens_tensor, non_blocking=True)
input_buffers["block_tables"].copy_(
attn_metadata.decode_metadata.block_tables, non_blocking=True)
# CUDA graph buffer is padded so only perform a partial copy based on
# num_positions
input_buffers["input_positions"][:num_positions].copy_(
input_positions, non_blocking=True)
if is_encoder_decoder_model:
raise NotImplementedError(
"VACCMLAState does not support encoder/decoder yet")
def begin_forward(self, model_input):
return
@dataclass
class VACCMLAMetadata(MLACommonMetadata):
"""Metadata for VACCMLAMetadata.
NOTE: Any python object stored here is not updated when it is
cuda-graph replayed. If you have values that need to be changed
dynamically, it should be stored in tensor. The tensor has to be
updated from `CUDAGraphRunner.forward` API.
"""
    # (batch_size,). The sequence length per sequence. Sequence length means
    # the computed tokens + new tokens. None if it is a decoding.
seq_lens: Optional[List[int]]
# seq_lens stored as a tensor.
seq_lens_tensor: Optional[torch.Tensor]
# NOTE(sang): Definition of context_len, query_len, and seq_len.
# |---------- N-1 iteration --------|
# |---------------- N iteration ---------------------|
# |- tokenA -|......................|-- newTokens ---|
# |---------- context_len ----------|
# |-------------------- seq_len ---------------------|
# |-- query_len ---|
# Maximum sequence length among prefill batch. 0 if there are decoding
# requests only.
max_prefill_seq_len: int
# Maximum sequence length among decode batch. 0 if there are prefill
# requests only.
max_decode_seq_len: int
# (batch_size,) A tensor of context lengths (tokens that are computed
# so far).
context_lens_tensor: Optional[torch.Tensor]
# (batch_size, max_blocks_per_seq).
# Block addresses per sequence. (Seq id -> list of physical block)
# E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks
# in the kv cache. Each block can contain up to block_size tokens.
# 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph
# captured.
block_tables: Optional[torch.Tensor]
    # Whether or not cuda graph is enabled.
# Cuda-graph is currently enabled for decoding only.
# TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
use_cuda_graph: bool
# Maximum query length in the batch.
max_query_len: Optional[int] = None
input_positions: Optional[torch.Tensor] = None
    # Max number of query tokens among requests in the batch.
max_decode_query_len: Optional[int] = None
# (batch_size + 1,). The cumulative subquery lengths of the sequences in
# the batch, used to index into subquery. E.g., if the subquery length
# is [4, 6], it is [0, 4, 10].
query_start_loc: Optional[torch.Tensor] = None
# (batch_size + 1,). The cumulative sequence lengths of the sequences in
# the batch, used to index into sequence. E.g., if the sequence length is
# [4, 6], it is [0, 4, 10].
seq_start_loc: Optional[torch.Tensor] = None
_cached_prefill_metadata: Optional["VACCMLAMetadata"] = None
_cached_decode_metadata: Optional["VACCMLAMetadata"] = None
num_prefill_tokens: int
num_kv_splits: int = 4 # TODO(lucas) add heuristic
attn_logits: Optional[torch.Tensor] = None
req_idx: Optional[torch.Tensor] = None
# The dimension of the attention heads
head_dim: Optional[int] = None
def __post_init__(self):
supported_head_sizes = VACCMLABackend.get_supported_head_sizes()
if self.head_dim is not None and self.head_dim \
not in supported_head_sizes:
            raise ValueError(
                f"Only {supported_head_sizes} are supported for head_dim, "
                f"received {self.head_dim}.")
@property
def prefill_metadata(self) -> Optional["VACCMLAMetadata"]:
if self.num_prefills == 0:
return None
if self._cached_prefill_metadata is not None:
return self._cached_prefill_metadata
assert self.seq_lens is not None
assert self.seq_lens_tensor is not None
# Compute some attn_metadata fields which default to None
query_start_loc = (None if self.query_start_loc is None else
self.query_start_loc[:self.num_prefills + 1])
slot_mapping = (None if self.slot_mapping is None else
self.slot_mapping[:self.num_prefill_tokens])
seq_lens = (None if self.seq_lens is None else
self.seq_lens[:self.num_prefills])
seq_lens_tensor = (None if self.seq_lens_tensor is None else
self.seq_lens_tensor[:self.num_prefills])
seq_start_loc = (None if self.seq_start_loc is None else
self.seq_start_loc[:self.num_prefills + 1])
context_lens_tensor = (None if self.context_lens_tensor is None else
self.context_lens_tensor[:self.num_prefills])
block_tables = (None if self.block_tables is None else
self.block_tables[:self.num_prefills])
input_positions = (None if self.input_positions is None else
self.input_positions[:self.num_prefill_tokens])
self._cached_prefill_metadata = VACCMLAMetadata(
num_prefills=self.num_prefills,
num_prefill_tokens=self.num_prefill_tokens,
num_decode_tokens=0,
slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=self.
multi_modal_placeholder_index_maps,
enable_kv_scales_calculation=self.enable_kv_scales_calculation,
input_positions=input_positions,
seq_lens=seq_lens,
seq_lens_tensor=seq_lens_tensor,
            max_prefill_seq_len=self.max_prefill_seq_len,
max_decode_seq_len=0,
query_start_loc=query_start_loc,
seq_start_loc=seq_start_loc,
context_lens_tensor=context_lens_tensor,
block_tables=block_tables,
use_cuda_graph=False,
head_dim=self.head_dim)
return self._cached_prefill_metadata
@property
def decode_metadata(self) -> Optional["VACCMLAMetadata"]:
if self.num_decode_tokens == 0:
return None
if self._cached_decode_metadata is not None:
return self._cached_decode_metadata
assert self.seq_lens_tensor is not None
# Compute some attn_metadata fields which default to None
slot_mapping = (None if self.slot_mapping is None else
self.slot_mapping[self.num_prefill_tokens:])
seq_lens_tensor = (None if self.seq_lens_tensor is None else
self.seq_lens_tensor[self.num_prefills:])
block_tables = (None if self.block_tables is None else
self.block_tables[self.num_prefills:])
input_positions = (None if self.input_positions is None else
self.input_positions[self.num_prefill_tokens:])
self._cached_decode_metadata = VACCMLAMetadata(
num_prefills=0,
num_prefill_tokens=0,
num_decode_tokens=self.num_decode_tokens,
slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=None,
enable_kv_scales_calculation=True,
            seq_lens=(None if self.seq_lens is None else
                      self.seq_lens[self.num_prefills:]),
seq_lens_tensor=seq_lens_tensor,
max_decode_query_len=self.max_decode_query_len,
max_query_len=self.max_query_len,
max_prefill_seq_len=0,
max_decode_seq_len=self.max_decode_seq_len,
# Batch may be composed of prefill|decodes, adjust query start
# indices to refer to the start of decodes. E.g.
# in tokens:[3 prefills|6 decodes], query_start_loc=[3,9] => [0,6].
query_start_loc=(self.query_start_loc[self.num_prefills:] -
self.query_start_loc[self.num_prefills])
if self.query_start_loc is not None else None,
seq_start_loc=self.seq_start_loc[self.num_prefills:]
if self.seq_start_loc is not None else None,
context_lens_tensor=None,
block_tables=block_tables,
use_cuda_graph=self.use_cuda_graph,
input_positions=input_positions,
head_dim=self.head_dim)
return self._cached_decode_metadata
def advance_step(self,
model_input: "ModelInputForVACCWithSamplingMetadata",
sampled_token_ids: Optional[torch.Tensor],
block_size: int,
num_seqs: int,
num_queries: int,
turn_prefills_into_decodes: bool = False):
"""
Update metadata in-place to advance one decode step.
"""
        # When using cudagraph, num_seqs is padded to the next captured
        # batch size, but num_queries tracks the actual number of requests in
        # the batch. For --enforce-eager mode, num_seqs == num_queries.
if num_seqs != num_queries:
assert num_seqs > num_queries
assert self.use_cuda_graph
if turn_prefills_into_decodes:
            # When Multi-Step is enabled with Chunked-Prefill, prefills and
            # decodes are scheduled together. In the first step, all the
            # prefills turn into decodes. This update reflects that
            # conversion.
assert self.num_decode_tokens + self.num_prefills == num_seqs
self.num_decode_tokens += self.num_prefills
self.num_prefills = 0
# self.num_prefill_tokens = 0
# self.max_prefill_seq_len = 0
self.max_query_len = 1
self.slot_mapping = self.slot_mapping[:num_seqs]
else:
assert self.seq_lens is not None
assert self.max_decode_seq_len == max(self.seq_lens)
assert self.num_prefills == 0
assert self.num_prefill_tokens == 0
assert self.num_decode_tokens == num_seqs
assert self.slot_mapping.shape == (num_seqs, )
assert self.seq_lens is not None
assert len(self.seq_lens) == num_seqs
assert self.seq_lens_tensor is not None
assert self.seq_lens_tensor.shape == (num_seqs, )
# assert self.max_query_len == 1
# assert self.max_prefill_seq_len == 0
assert self.query_start_loc is not None
assert self.query_start_loc.shape == (num_queries + 1, )
assert self.seq_start_loc is not None
assert self.seq_start_loc.shape == (num_seqs + 1, )
assert self.context_lens_tensor is not None
assert self.context_lens_tensor.shape == (num_queries, )
assert self.block_tables is not None
assert self.block_tables.shape[0] == num_seqs
# Update query lengths. Note that we update only queries and not seqs,
# since tensors may be padded due to captured cuda graph batch size
for i in range(num_queries):
self.seq_lens[i] += 1
        self.max_decode_seq_len = max(self.seq_lens)
ops.advance_step_flashattn(num_seqs=num_seqs,
num_queries=num_queries,
block_size=block_size,
input_tokens=model_input.input_tokens,
sampled_token_ids=sampled_token_ids,
input_positions=model_input.input_positions,
seq_lens=self.seq_lens_tensor,
slot_mapping=self.slot_mapping,
block_tables=self.block_tables)
class VACCMLAMetadataBuilder(AttentionMetadataBuilder[VACCMLAMetadata]):
def __init__(self, input_builder: "ModelInputForVACCBuilder"):
self.chunked_prefill = True
if hasattr(input_builder, 'chunked_prefill'):
self.chunked_prefill = input_builder.chunked_prefill
self.input_builder = input_builder
self.runner = input_builder.runner
self.sliding_window = input_builder.sliding_window
self.block_size = input_builder.block_size
def prepare(self):
self.slot_mapping: List[int] = []
self.prefill_seq_lens: List[int] = []
self.context_lens: List[int] = []
self.block_tables: List[List[int]] = []
self.curr_seq_lens: List[int] = []
self.input_positions: List[int] = []
self.multimodal_placeholder_maps: Dict[
str,
MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap)
self.num_prefills = 0
self.num_prefill_tokens = 0
self.num_decode_tokens = 0
self.has_prefix_cache_hit = False
def build(self, seq_lens: List[int], query_lens: List[int],
cuda_graph_pad_size: int, batch_size: int):
"""Build attention metadata with on-device tensors.
Args:
seq_lens: The maybe padded sequence lengths of the input sequences.
query_lens: The query lengths of the input sequences.
cuda_graph_pad_size: The padding size for cuda graph.
-1 if cuda graph is not used.
batch_size: The maybe padded batch size.
"""
self.input_data = self.input_builder.input_data
        self.slot_mapping = self.input_data.slot_mapping
        self.context_lens = self.input_data.context_lens
        if self.input_data.num_prefill_tokens != 0:
            self.block_tables = self.input_data.prefill_block_tables
        else:
            self.block_tables = self.input_data.decode_block_tables
        self.input_positions = self.input_data.input_positions
self.prefill_seq_lens = seq_lens[0:self.input_data.num_prefills]
self.num_prefills = self.input_data.num_prefills
self.num_prefill_tokens = self.input_data.num_prefill_tokens
self.num_decode_tokens = self.input_data.num_decode_tokens
device = self.runner.device
use_captured_graph = cuda_graph_pad_size != -1
# max_query_len = max(query_lens)
# decode_query_lens = query_lens[self.num_prefills:]
# if len(decode_query_lens) > 0:
# max_decode_query_len = max(decode_query_lens)
# else:
# max_decode_query_len = 1
        max_prefill_seq_len = max(self.prefill_seq_lens, default=0)
        max_decode_seq_len = max(seq_lens[self.num_prefills:], default=0)
num_decode_tokens = self.num_decode_tokens
query_start_loc = list(accumulate(query_lens, initial=0))
seq_start_loc = list(accumulate(seq_lens, initial=0))
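        # Cumulative start offsets into the packed token dimension,
        # e.g. query_lens=[4, 6] -> query_start_loc=[0, 4, 10].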
num_seqs = len(seq_lens)
if use_captured_graph:
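            # CUDA graphs are captured for fixed batch sizes: pad slot_mapping
            # with PAD_SLOT_ID and count decode tokens against the padded
            # batch_size so replay sees the captured shapes.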
self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size)
            self.block_tables.extend([[]] * cuda_graph_pad_size)
num_decode_tokens = batch_size - self.num_prefill_tokens
block_tables = self._get_graph_runner_block_tables(
num_seqs, self.block_tables)
else:
block_tables = make_tensor_with_pad(
self.block_tables,
pad=0,
dtype=torch.int,
device=device,
)
# assert max_query_len > 0, ("query_lens: {}".format(query_lens))
assert device is not None
context_lens_tensor = async_tensor_h2d(self.context_lens, torch.int,
device, self.runner.pin_memory)
seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device,
self.runner.pin_memory)
input_positions = async_tensor_h2d(self.input_positions, torch.int,
device, self.runner.pin_memory)
slot_mapping_tensor = async_tensor_h2d(self.slot_mapping, torch.int,
device, self.runner.pin_memory)
query_start_loc_tensor = async_tensor_h2d(query_start_loc, torch.int32,
device,
self.runner.pin_memory)
seq_start_loc_tensor = async_tensor_h2d(seq_start_loc, torch.int32,
device, self.runner.pin_memory)
placeholder_index_maps = {
modality: placeholder_map.index_map()
for modality, placeholder_map in
self.multimodal_placeholder_maps.items()
}
return VACCMLAMetadata(
num_prefills=self.num_prefills,
slot_mapping=slot_mapping_tensor,
num_prefill_tokens=self.num_prefill_tokens,
num_decode_tokens=num_decode_tokens,
seq_lens=seq_lens,
multi_modal_placeholder_index_maps=placeholder_index_maps,
enable_kv_scales_calculation=True,
input_positions=input_positions,
seq_lens_tensor=seq_lens_tensor,
# max_query_len=max_query_len,
# max_decode_query_len=None,
            max_prefill_seq_len=max_prefill_seq_len,
            max_decode_seq_len=max_decode_seq_len,
query_start_loc=query_start_loc_tensor,
seq_start_loc=seq_start_loc_tensor,
context_lens_tensor=context_lens_tensor,
block_tables=block_tables,
use_cuda_graph=use_captured_graph,
num_kv_splits=4, # TODO(lucas) add heuristic
head_dim=self.runner.model_config.get_head_size(),
)
class VACCMLAImpl(MLACommonImpl[VACCMLAMetadata]):
def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: int,
alibi_slopes: Optional[List[float]],
sliding_window: Optional[int],
kv_cache_dtype: str,
blocksparse_params: Optional[Dict[str, Any]],
logits_soft_cap: Optional[float],
attn_type: str,
kv_sharing_target_layer_name: Optional[str],
# MLA Specific Arguments
**kwargs) -> None:
super().__init__(num_heads, head_size, scale, num_kv_heads,
alibi_slopes, sliding_window, kv_cache_dtype,
blocksparse_params, logits_soft_cap, attn_type,
kv_sharing_target_layer_name, **kwargs)
unsupported_features = [
alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap
]
if any(unsupported_features):
raise NotImplementedError(
"VACCMLAImpl does not support one of the following: "
"alibi_slopes, sliding_window, blocksparse_params, "
"logits_soft_cap")
if attn_type != AttentionType.DECODER:
raise NotImplementedError("Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"VACCMLAImpl")
def extract_weights(self):
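        # Collect whichever MLA projection weights (and optional quantization
        # scales) this impl has materialized; which attributes exist depends
        # on the weight-absorption / quantization path taken at load time.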
weights = {}
if hasattr(self, 'W_Q'):
weights["W_Q"] = self.W_Q
if hasattr(self, 'W_Q_scales'):
weights["W_Q_scales"] = self.W_Q_scales
if hasattr(self, 'W_QR'):
weights['W_QR'] = self.W_QR
if hasattr(self, 'W_QR_scales'):
weights["W_QR_scales"] = self.W_QR_scales
if hasattr(self, 'W_Q_QR'):
weights["W_Q_QR"] = self.W_Q_QR
if hasattr(self, 'W_Q_QR_scales'):
weights["W_Q_QR_scales"] = self.W_Q_QR_scales
if hasattr(self, 'W_UK'):
weights['W_UK'] = self.W_UK
if hasattr(self, 'W_UK_scales'):
weights['W_UK_scales'] = self.W_UK_scales
if hasattr(self, 'W_Q_UK_scales'):
weights['W_Q_UK_scales'] = self.W_Q_UK_scales
if hasattr(self, 'W_UV'):
weights['W_UV'] = self.W_UV
if hasattr(self, 'W_UV_scales'):
weights['W_UV_scales'] = self.W_UV_scales
if hasattr(self, 'W_UV_O'):
weights['W_UV_O'] = self.W_UV_O
if hasattr(self, 'W_UV_O_scales'):
weights['W_UV_O_scales'] = self.W_UV_O_scales
return weights
def _forward_prefill(
self,
q: torch.Tensor,
kv_c_normed: torch.Tensor,
k_pe: torch.Tensor,
kv_c_and_k_pe_cache: torch.Tensor,
attn_metadata: VACCMLAMetadata,
) -> torch.Tensor:
assert isinstance(attn_metadata, VACCMLAMetadata)
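        # Prefill uses the non-absorbed MLA path: up-project the compressed
        # latent kv_c into per-head k_nope and v via kv_b_proj, then rebuild
        # full keys below.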
kv_nope = self.kv_b_proj(kv_c_normed)[0]\
.view(-1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
k_nope, v = kv_nope\
.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
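        # k_pe is shared across heads (rope part of the latent cache); expand
        # it to every head so k matches the full qk head dim
        # (qk_nope_head_dim + qk_rope_head_dim).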
k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1)
v = v.contiguous()
# For MLA the v head dim is smaller than qk head dim so we pad out
# v with 0s to match the qk head dim
# v_padded = torch.nn.functional.pad(v, [0, q.shape[-1] - v.shape[-1]],
# value=0)
# attn_output = torch.vacc.scaled_dot_product_attention(
# query=q,
# key=k,
# value=v_padded,
# attn_mask=None,
# dropout_p=0,
# is_causal=True,
# is_train=False,
# recompute=False,
# flash_attention=True,
# sm_scale=self.scale
# )
# attn_output = attn_output\
# .view(-1, self.num_heads, q.shape[-1])[..., :v.shape[-1]]\
# .reshape(-1, self.num_heads * v.shape[-1])
seq_lens = attn_metadata.prefill_metadata.seq_lens
if len(seq_lens) == 1:
            # VACC supports different head dims for v and qk.
attn_output = torch.vacc.scaled_dot_product_attention(
query=q,
key=k,
value=v,
attn_mask=None,
dropout_p=0,
is_causal=True,
is_train=False,
recompute=False,
flash_attention=False,
sm_scale=self.scale
)
attn_out = attn_output.view(-1, self.num_heads * v.shape[-1])
else:
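            # Multiple prefill sequences are packed along the token dim; run
            # SDPA per sequence so causal masking never crosses sequence
            # boundaries (the naive path has no varlen support).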
attn_outs = []
start = 0
for seq in seq_lens:
end = start + seq
attn_out = torch.vacc.scaled_dot_product_attention(
query=q[start:end, :],
key=k[start:end, :],
value=v[start:end, :],
attn_mask=None,
dropout_p=0,
is_causal=True,
is_train=False,
recompute=False,
flash_attention=False,
sm_scale=self.scale
)
start = end
attn_outs.append(attn_out)
attn_out = torch.cat(attn_outs, dim=0).view(-1, self.num_heads * v.shape[-1])
return self.o_proj(attn_out)[0]
def _forward_decode(
self,
q_nope: torch.Tensor,
q_pe: torch.Tensor,
kv_c_and_k_pe_cache: torch.Tensor,
attn_metadata: VACCMLAMetadata,
) -> torch.Tensor:
assert kv_c_and_k_pe_cache.numel() > 0
if self.kv_cache_dtype.startswith("fp8"):
raise NotImplementedError("FP8 Triton MLA not yet supported")
decode_meta = attn_metadata.decode_metadata
assert decode_meta is not None
B = q_nope.shape[0]
q = torch.cat([q_nope, q_pe], dim=-1)
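        # q_nope has already been absorbed into the latent space by
        # MLACommonImpl, so q lines up with the cached layout:
        # [kv_lora_rank | qk_rope_head_dim] == cache head_size.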
o = torch.zeros(B,
self.num_heads,
self.kv_lora_rank,
dtype=q.dtype,
device=q.device)
# Add a head dim of 1
kv_c_and_k_pe_cache = kv_c_and_k_pe_cache.unsqueeze(2)
# print(f"kv_c_and_k_pe_cache: {kv_c_and_k_pe_cache.shape} ")
kv_c_cache = kv_c_and_k_pe_cache[..., :self.kv_lora_rank]
# Run MQA using paged_attention
# o = torch.vacc.paged_attention(
# query=q,
# key_cache=kv_c_and_k_pe_cache,
# value_cache=kv_c_cache,
# block_table=decode_meta.block_tables,
# seq_len=decode_meta.seq_lens_tensor,
# out=o,
# sm_scale=self.scale
# )
# Run MQA using spda
# t0 = time.time()
        o = vacc_paged_attention_naive(
            q,
            kv_c_and_k_pe_cache,
            kv_c_cache,
            block_table=decode_meta.block_tables,
            # seq_lens=decode_meta.seq_lens_tensor,
            seq_lens=decode_meta.seq_lens,
            out=o,
            sm_scale=self.scale)
# print(f'{os.getpid()} paged_atten(seq: {decode_meta.seq_lens}) time: {time.time() - t0}')
return self._v_up_proj_and_o_proj(o)
def vacc_paged_attention_naive(
query: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
block_table: torch.Tensor,
# seq_lens: torch.Tensor,
        seq_lens: List[int],
        out: Optional[torch.Tensor] = None,
        sm_scale: float = -1
) -> torch.Tensor:
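    # Naive reference path: materialize each sequence's KV from the paged
    # cache as contiguous tensors and run dense SDPA. Decode queries are a
    # single token, so no causal mask is needed (is_causal=False).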
    # Fast path: guarantee batch=1 perf with a single dense SDPA call.
if len(seq_lens) == 1:
k = key_cache.view(-1, key_cache.shape[2], key_cache.shape[3])[:seq_lens[0]]
v = value_cache.view(-1, value_cache.shape[2], value_cache.shape[3])[:seq_lens[0]]
attn_out = torch.vacc.scaled_dot_product_attention(
query=query,
key=k,
value=v,
attn_mask=None,
dropout_p=0,
is_causal=False,
is_train=False,
recompute=False,
flash_attention=False,
sm_scale=sm_scale
)
else:
# t0 = time.time()
attn_outs = []
for i in range(len(seq_lens)):
            k_slices = key_cache[block_table[i], :, :, :]
            k = torch.cat([k_slices[j, :, :, :].unsqueeze(1)
                           for j in range(len(block_table[i]))], dim=0)
            k = k.view(-1, key_cache.shape[2], key_cache.shape[3])[:seq_lens[i]]
            v_slices = value_cache[block_table[i], :, :, :]
            v = torch.cat([v_slices[j, :, :, :].unsqueeze(1)
                           for j in range(len(block_table[i]))], dim=0)
            v = v.view(-1, value_cache.shape[2], value_cache.shape[3])[:seq_lens[i]]
attn_out = torch.vacc.scaled_dot_product_attention(
query=query[i:i+1,:,:],
key=k,
value=v,
attn_mask=None,
dropout_p=0,
is_causal=False,
is_train=False,
recompute=False,
flash_attention=False,
sm_scale=sm_scale
)
attn_outs.append(attn_out)
attn_out = torch.cat(attn_outs, dim=0)
# print(f'{os.getpid()} call spda(seq: {seq_lens}) time: {time.time() - t0}')
return attn_out
# MLA single op impl
def vacc_paged_attention_naive_singleop(
query: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
        seq_lens: int,
        block_table: Optional[torch.Tensor] = None,
        out: Optional[torch.Tensor] = None,
        sm_scale: float = -1
) -> torch.Tensor:
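    # Reference single-op MLA decode for one sequence: scores are
    # q_nope @ kv_c plus q_pe @ k_pe against the shared latent cache, and the
    # output stays in the latent (kv_lora_rank) space.
    # NOTE: the hard-coded 512 below assumes kv_lora_rank == 512 and
    # qk_rope_head_dim == 64 (DeepSeek-style MLA).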
k = key_cache.view(-1, key_cache.shape[2], key_cache.shape[3])[:seq_lens]
v = value_cache.view(-1, value_cache.shape[2], value_cache.shape[3])[:seq_lens].squeeze(1)
pe_cache = k[..., 512:].squeeze(1)
print(f'q:{query[..., :512].shape} v:{v.shape} pe_cache:{pe_cache.shape}')
q_nope_kv_c = torch.einsum("shc,tc->sht", query[..., :512], v)
q_pe_k_pe = torch.einsum("shr,tr->sht", query[..., 512:], pe_cache)
scores = (q_nope_kv_c + q_pe_k_pe) * sm_scale
scores = scores.softmax(dim=-1, dtype=torch.float32).type_as(query)
o = torch.einsum("sht,tc->shc", scores, v)
return o