Files
2026-04-24 09:58:03 +08:00

278 lines
9.8 KiB
Python

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
import torch
from vllm.attention import AttentionMetadata
from vllm.config import VllmConfig
from vllm.distributed import (
get_tensor_model_parallel_world_size,
get_tp_group
)
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import ReplicatedLinear
from vllm_mlu import _mlu_ops as mlu_ops
from vllm_mlu.model_executor.layers.compressor import (
Compressor,
rotate_activation,
)
from vllm_mlu.v1.attention.backends.utils import get_common_metadata
# Module-level logger, created through vLLM's logging helper and named after
# this module.
logger = init_logger(__name__)
class Indexer(torch.nn.Module):
    """Select the top-k compressed KV entries each query token attends to.

    The module projects the low-rank query (``qr``) into per-head index
    queries, projects the hidden states into per-head weights, runs the
    ``Compressor`` on the hidden states, and then calls an MLU kernel that
    fills a sparse block table plus per-token context lengths. Prefill and
    decode tokens are handled by separate kernels (``forward_prefill`` /
    ``forward_decode``).
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        rope,
        compress_ratio: int = 4,
        prefix: str = "",
        **kwargs,
    ):
        """Build the projections and the compressor.

        Args:
            vllm_config: Source of the HF model config, the cache block size
                and ``max_model_len``.
            rope: Rotary-embedding callable; it is stored as ``rotary_emb``
                and also handed to the ``Compressor``.
            compress_ratio: Number of original tokens per compressed KV
                entry (sequence lengths are floor-divided by this value).
            prefix: Parameter-name prefix used for the sub-layers.
            **kwargs: Accepted for call-site compatibility; not used here.
        """
        super().__init__()
        config = vllm_config.model_config.hf_config

        self.dim = config.dim
        self.n_heads = config.index_n_heads
        self.tp_size = get_tensor_model_parallel_world_size()
        self.n_local_heads = config.index_n_heads // self.tp_size
        self.head_dim = config.index_head_dim
        self.rope_head_dim = config.rope_head_dim
        self.index_topk = config.index_topk
        self.q_lora_rank = config.q_lora_rank
        self.window_size = config.window_size
        self.block_size = vllm_config.cache_config.block_size

        # Both projections are replicated (not sharded), so every rank
        # produces all ``n_heads`` index heads.
        self.wq_b = ReplicatedLinear(
            self.q_lora_rank,
            self.n_heads * self.head_dim,
            bias=False,
            quant_config=None,
            prefix=f"{prefix}.wq_b",
        )
        self.weights_proj = ReplicatedLinear(
            self.dim,
            self.n_heads,
            bias=False,
            quant_config=None,
            params_dtype=torch.bfloat16,
            prefix=f"{prefix}.weights_proj",
        )

        self.softmax_scale = self.head_dim ** -0.5
        # Scale passed to the kernels: the per-head softmax scale combined
        # with a 1/sqrt(n_heads) factor.
        self.merged_softmax_scale = (
            (self.head_dim ** -0.5) * (self.n_heads ** -0.5)
        )
        self.compress_ratio = compress_ratio
        self.max_model_len = vllm_config.model_config.max_model_len
        self.rotary_emb = rope
        self.tp_group = get_tp_group()
        self.compressor = Compressor(
            vllm_config,
            self.rotary_emb,
            compress_ratio,
            self.head_dim,
            True,
            f"{prefix}.compressor",
        )
        self.freqs_cis = None

    def forward_prefill(
        self,
        q: torch.Tensor,
        k_cache: torch.Tensor,
        weights: torch.Tensor,
        attn_metadata: AttentionMetadata,
        k_full: torch.Tensor,
        context_lens: torch.Tensor,
    ):
        """Run top-k index selection for the prefill tokens.

        Args:
            q: Index queries of the prefill tokens,
                (tokens, index_head, index_head_dim).
            k_cache: Unused in this path; kept for signature compatibility.
            weights: Per-head weights of the prefill tokens; a trailing
                singleton dim is appended before the kernel call.
            attn_metadata: Must carry ``prefill`` metadata without chunked
                context.
            k_full: Compressed keys handed to the kernel as ``key_value``,
                (tokens, index_head_dim).
            context_lens: Compressed KV length per request.

        Returns:
            ``(sparse_block_table, sparse_context_lens)`` with shapes
            ``[num_prefill_tokens, index_topk]`` and ``[num_prefill_tokens]``
            (int32). Both are allocated uninitialised here and are
            presumably filled in place by the kernel.
        """
        assert attn_metadata.prefill.chunked_context is None, \
            "Prefill chunked context is not supported."

        device = q.device
        cu_seq_q_lens = attn_metadata.prefill.query_start_loc

        # Exclusive prefix sum of the compressed lengths (leading zero).
        cu_seq_k_lens = torch.zeros(
            context_lens.size(0) + 1, dtype=torch.int32, device=device,
        )
        torch.cumsum(context_lens, dim=0, out=cu_seq_k_lens[1:])

        num_prefill_tokens = attn_metadata.num_prefill_tokens
        new_block_tables = torch.empty(
            [num_prefill_tokens, self.index_topk],
            dtype=torch.int32,
            device=device,
        )
        new_context_lens = torch.empty(
            [num_prefill_tokens],
            dtype=torch.int32,
            device=device,
        )

        q_seq_lens = cu_seq_q_lens[1:] - cu_seq_q_lens[:-1]
        batch_size = q_seq_lens.size(0)
        # The host-side value is needed to size the block table.
        max_seq_len = q_seq_lens.max().item()
        max_compressed_kv_len = max_seq_len // self.compress_ratio

        # The layout of linear kv is as follows:
        # | bs0_origin_kv | bs1_origin_kv | bs0_compressed_kv | bs1_compressed_kv |
        # Row i therefore holds consecutive indices starting at request i's
        # compressed offset, shifted by the total origin_kv length
        # (cu_seq_q_lens[-1]). Built with one broadcasted add instead of a
        # per-request loop so no extra host/device syncs are triggered.
        row_starts = cu_seq_k_lens[:batch_size].unsqueeze(1)
        col_offsets = torch.arange(
            max_compressed_kv_len, dtype=torch.int32, device=device,
        ).unsqueeze(0)
        kv_cache_block_table = row_starts + col_offsets + cu_seq_q_lens[-1]

        # query:   (tokens, index_head, index_head_dim)
        # k_full:  (tokens, index_head_dim)
        # weights: (tokens, index_head, 1)
        mlu_ops.masked_indexer_select_paged_kv_prefill(
            query=q,
            key_value=k_full,
            weights=weights.unsqueeze(-1),
            kv_cache_block_table=kv_cache_block_table,
            cu_seq_q_lens=cu_seq_q_lens,
            cu_seq_k_lens=cu_seq_k_lens,
            index_topk=self.index_topk,
            kv_cache_block_size=self.block_size,
            softmax_scale=self.merged_softmax_scale,
            q_scale=None,
            k_scale_cache=None,
            sparse_block_table=new_block_tables,
            sparse_context_lens=new_context_lens,
            compress_ratio=self.compress_ratio,
            kv_cache_block_table_offset=None,
        )
        return new_block_tables, new_context_lens

    def forward_decode(
        self,
        q: torch.Tensor,
        x: torch.Tensor,
        k_cache: torch.Tensor,
        weights: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ):
        """Run top-k index selection for the decode tokens.

        Args:
            q: Index queries of the decode tokens; the leading token dim is
                reshaped to (batch, seq_q).
            x: Hidden states of the decode tokens; only ``x.shape[0]`` is
                read, to derive the per-request query length.
            k_cache: Paged index key cache handed to the kernel.
            weights: Per-head weights of the decode tokens; reshaped like
                ``q`` and given a trailing singleton dim.
            attn_metadata: Must carry ``decode`` metadata (block table and
                sequence lengths).

        Returns:
            ``(sparse_block_table, sparse_context_lens)``. The block table is
            ``[batch, index_topk]`` when each request has one query token;
            the context lengths are ``[num_decode_tokens]`` (int32).
        """
        block_table = attn_metadata.decode.block_table
        device = block_table.device
        batch_size = block_table.shape[0]
        # Assumes every decode request contributes the same number of tokens.
        seq_len = x.shape[0] // batch_size

        q = q.view(batch_size, seq_len, *q.shape[1:])
        weights = weights.view(batch_size, seq_len, *weights.shape[1:])
        seq_lens = attn_metadata.decode.seq_lens

        num_decode_tokens = attn_metadata.num_decode_tokens
        new_block_tables = torch.empty(
            [batch_size, seq_len, self.index_topk],
            dtype=torch.int32,
            device=device,
        )
        new_context_lens = torch.empty(
            [num_decode_tokens],
            dtype=torch.int32,
            device=device,
        )
        # Every decode token uses the same offset: the window size.
        kv_cache_block_table_offset = torch.full(
            [num_decode_tokens],
            self.window_size,
            dtype=torch.int32,
            device=device,
        )

        mlu_ops.masked_indexer_select_paged_kv_decode(
            query=q,
            k_cache=k_cache,
            weights=weights.unsqueeze(-1),  # (bsz, seq_q, head_num, 1)
            kv_cache_block_table=block_table,
            k_context_lens=seq_lens // self.compress_ratio,
            k_cache_block_table=block_table,
            index_topk=self.index_topk,
            kv_cache_block_size=self.block_size,
            softmax_scale=self.merged_softmax_scale,
            q_scale=None,
            k_scale_cache=None,
            sparse_block_table=new_block_tables,
            sparse_context_lens=new_context_lens,
            compress_ratio=self.compress_ratio,
            kv_cache_block_table_offset=kv_cache_block_table_offset,
        )

        # [batch, seq_q, index_topk] -> [batch, index_topk]
        # NOTE(review): squeeze(1) is a no-op when seq_q > 1, leaving a 3-D
        # table - confirm whether multi-token decode must be flattened.
        new_block_tables = new_block_tables.squeeze(1)
        return new_block_tables, new_context_lens

    def forward(
        self,
        x: torch.Tensor,
        qr: torch.Tensor,
        positions: torch.Tensor,
        offsets: torch.Tensor,
        attn_metadata: AttentionMetadata,
        batch_to_kv_state: torch.Tensor,
        indexer_kv_cache: torch.Tensor,
        compressor_slot_mapping: torch.Tensor,
    ):
        """Compute sparse block tables and context lengths for a batch.

        Tokens are laid out decode-first: ``[:num_decode_tokens]`` are decode
        tokens, the remainder are prefill tokens.

        Args:
            x: Hidden states, input of ``weights_proj`` and the compressor.
            qr: Low-rank query, input of ``wq_b``.
            positions: Token positions for the rotary embedding/compressor.
            offsets: Unused here; kept for signature compatibility.
            attn_metadata: Metadata with optional ``prefill`` / ``decode``
                parts and the decode token count.
            batch_to_kv_state: Forwarded to the compressor.
            indexer_kv_cache: Index KV cache, forwarded to the compressor and
                to the prefill/decode kernels.
            compressor_slot_mapping: Forwarded to the compressor.

        Returns:
            ``(new_block_tables, new_context_lens)`` from the prefill path,
            the decode path, or their concatenation along dim 0.
        """
        common_metadata = get_common_metadata()
        query_start_loc = common_metadata.query_start_loc
        # NOTE(review): these lengths come from the common metadata while
        # forward_prefill uses the prefill-only query_start_loc - confirm
        # the two agree for mixed prefill/decode batches.
        query_lens = query_start_loc[1:] - query_start_loc[:-1]

        rd = self.rope_head_dim
        q = self.wq_b(qr)[0]
        q = q.unflatten(-1, (self.n_heads, self.head_dim))
        # Rotary embedding on the last ``rd`` channels of every head; the
        # return value is discarded, so the call presumably updates the
        # slice view in place.
        self.rotary_emb(positions, q[..., -rd:], None, only_prefill=False)
        q_pack = rotate_activation(q)
        weights_pack = self.weights_proj(x)[0]  # (tokens, n_heads)

        num_decode_tokens = attn_metadata.num_decode_tokens
        compressed_kv = self.compressor(
            x,
            positions,
            attn_metadata,
            batch_to_kv_state,
            indexer_kv_cache,
            0,
            compressor_slot_mapping,
        )

        if attn_metadata.prefill:
            assert compressed_kv is not None and compressed_kv.dim() == 3
            compressed_kv = compressed_kv.squeeze(-2)
            compressed_context_lens = query_lens // self.compress_ratio
            prefill_q = q_pack[num_decode_tokens:, ...]
            prefill_weights = weights_pack[num_decode_tokens:, ...]
            prefill_block_tables, prefill_context_lens = self.forward_prefill(
                prefill_q,
                indexer_kv_cache,
                prefill_weights,
                attn_metadata,
                compressed_kv,
                compressed_context_lens,
            )

        if attn_metadata.decode:
            decode_x = x[:num_decode_tokens, ...]
            decode_q = q_pack[:num_decode_tokens, ...]
            # Slice the weights exactly like q/x. The previous
            # ``weights_pack[num_prefills:]`` only matched for decode-only
            # batches and misaligned weights and queries in mixed batches.
            decode_weights = weights_pack[:num_decode_tokens, ...]
            decode_block_tables, decode_context_lens = self.forward_decode(
                decode_q,
                decode_x,
                indexer_kv_cache,
                decode_weights,
                attn_metadata,
            )

        if attn_metadata.prefill and attn_metadata.decode:
            # NOTE(review): results are concatenated prefill-first although
            # the tokens are sliced decode-first above - confirm the order
            # the consumer expects.
            new_block_tables = torch.cat(
                [prefill_block_tables, decode_block_tables], dim=0,
            )
            new_context_lens = torch.cat(
                [prefill_context_lens, decode_context_lens], dim=0,
            )
        elif attn_metadata.prefill:
            new_block_tables = prefill_block_tables
            new_context_lens = prefill_context_lens
        else:
            new_block_tables = decode_block_tables
            new_context_lens = decode_context_lens
        return new_block_tables, new_context_lens