278 lines
9.8 KiB
Python
278 lines
9.8 KiB
Python
|
|
# SPDX-License-Identifier: Apache-2.0
|
||
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
||
|
|
|
||
|
|
import torch
|
||
|
|
|
||
|
|
from vllm.attention import AttentionMetadata
|
||
|
|
from vllm.config import VllmConfig
|
||
|
|
from vllm.distributed import (
|
||
|
|
get_tensor_model_parallel_world_size,
|
||
|
|
get_tp_group
|
||
|
|
)
|
||
|
|
from vllm.logger import init_logger
|
||
|
|
from vllm.model_executor.layers.linear import ReplicatedLinear
|
||
|
|
|
||
|
|
from vllm_mlu import _mlu_ops as mlu_ops
|
||
|
|
from vllm_mlu.model_executor.layers.compressor import (
|
||
|
|
Compressor,
|
||
|
|
rotate_activation,
|
||
|
|
)
|
||
|
|
from vllm_mlu.v1.attention.backends.utils import get_common_metadata
|
||
|
|
|
||
|
|
# Module-level logger (created via vLLM's init_logger; not referenced by the
# Indexer code in this file chunk).
logger = init_logger(__name__)
|
||
|
|
|
||
|
|
|
||
|
|
class Indexer(torch.nn.Module):
    """Sparse-attention indexer producing per-token sparse block tables.

    For every token it projects ``qr`` into ``n_heads`` indexer queries
    (``wq_b`` + rotary embedding on the last ``rope_head_dim`` channels +
    ``rotate_activation``), projects ``x`` into ``n_heads`` per-head weights
    (``weights_proj``), and runs ``x`` through a ``Compressor`` to obtain the
    compressed indexer keys. The MLU kernels
    ``masked_indexer_select_paged_kv_prefill`` / ``..._decode`` then fill an
    int32 table of ``index_topk`` entries per token plus a per-token length,
    which ``forward`` returns.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        rope,
        compress_ratio: int = 4,
        prefix: str = "",
        **kwargs,
    ):
        """Build the indexer projections and its compressor.

        Args:
            vllm_config: Engine config. Model hyper-parameters are read from
                ``model_config.hf_config``, the KV block size from
                ``cache_config.block_size``.
            rope: Rotary-embedding module applied to the query heads and
                shared with the compressor.
            compress_ratio: Ratio between token count and compressed-key
                count (sequence lengths are floor-divided by it).
            prefix: Parameter-name prefix used for weight loading.
            **kwargs: Accepted for call-site compatibility and ignored.
        """
        super().__init__()
        config = vllm_config.model_config.hf_config
        self.dim = config.dim
        self.n_heads = config.index_n_heads
        self.tp_size = get_tensor_model_parallel_world_size()
        # Not used by this class's methods: both projections below are
        # ReplicatedLinear, so every rank computes all ``n_heads`` heads.
        self.n_local_heads = config.index_n_heads // self.tp_size
        self.head_dim = config.index_head_dim
        self.rope_head_dim = config.rope_head_dim
        self.index_topk = config.index_topk
        self.q_lora_rank = config.q_lora_rank
        self.window_size = config.window_size
        self.block_size = vllm_config.cache_config.block_size

        # q_lora_rank -> n_heads * head_dim query projection (unquantized).
        self.wq_b = ReplicatedLinear(
            self.q_lora_rank,
            self.n_heads * self.head_dim,
            bias=False,
            quant_config=None,
            prefix=f"{prefix}.wq_b",
        )

        # dim -> n_heads per-head weights, with bfloat16 parameters.
        self.weights_proj = ReplicatedLinear(
            self.dim,
            self.n_heads,
            bias=False,
            quant_config=None,
            params_dtype=torch.bfloat16,
            prefix=f"{prefix}.weights_proj",
        )

        self.softmax_scale = self.head_dim ** -0.5
        # Both 1/sqrt(head_dim) and 1/sqrt(n_heads) are folded into the single
        # scale handed to the selection kernels.
        self.merged_softmax_scale = (self.head_dim ** -0.5) * (self.n_heads ** -0.5)
        self.compress_ratio = compress_ratio
        self.max_model_len = vllm_config.model_config.max_model_len

        self.rotary_emb = rope
        self.tp_group = get_tp_group()

        # The positional ``True`` is forwarded unchanged; see ``Compressor``
        # for its meaning.
        self.compressor = Compressor(
            vllm_config,
            self.rotary_emb,
            compress_ratio,
            self.head_dim,
            True,
            f"{prefix}.compressor",
        )

        self.freqs_cis = None

    def forward_prefill(
        self,
        q: torch.Tensor,
        k_cache: torch.Tensor,
        weights: torch.Tensor,
        attn_metadata: AttentionMetadata,
        k_full: torch.Tensor,
        context_lens: torch.Tensor,
    ):
        """Select sparse KV entries for every prefill token.

        Args:
            q: Indexer queries for the prefill tokens,
                (tokens, index_head, index_head_dim).
            k_cache: Indexer KV cache. Not read on this path (the kernel reads
                ``k_full``); kept for signature parity with ``forward_decode``.
            weights: Per-head weights for the prefill tokens,
                (tokens, index_head); a trailing unit dim is added here.
            attn_metadata: Attention metadata; ``prefill.chunked_context``
                must be ``None`` (chunked prefill is not supported).
            k_full: Linear (non-paged) key buffer laid out as
                ``| origin kv of all requests | compressed kv of all requests |``.
            context_lens: Per-request compressed-key lengths.

        Returns:
            ``(block_tables, context_lens)`` of shapes
            ``(num_prefill_tokens, index_topk)`` and ``(num_prefill_tokens,)``,
            both int32. They are allocated uninitialised here and filled by
            the MLU kernel.

        Raises:
            AssertionError: if the prefill has a chunked context.
        """
        assert attn_metadata.prefill.chunked_context is None, \
            "Prefill chunked context is not supported."

        cu_seq_q_lens = attn_metadata.prefill.query_start_loc
        # cu_seq_k_lens[i] is where request i's compressed keys start inside
        # the compressed region of ``k_full``.
        cu_seq_k_lens = torch.zeros(
            context_lens.size(0) + 1, dtype=torch.int32, device=q.device
        )
        torch.cumsum(context_lens, dim=0, out=cu_seq_k_lens[1:])

        num_prefill_tokens = attn_metadata.num_prefill_tokens
        new_block_tables = torch.empty(
            [num_prefill_tokens, self.index_topk],
            dtype=torch.int32,
            device=q.device,
        )
        new_context_lens = torch.empty(
            [num_prefill_tokens], dtype=torch.int32, device=q.device
        )

        q_seq_lens = cu_seq_q_lens[1:] - cu_seq_q_lens[:-1]
        batch_size = q_seq_lens.size(0)
        # One host sync is unavoidable: the table width must be known on host.
        max_compressed_kv_len = q_seq_lens.max().item() // self.compress_ratio

        # The layout of linear kv is as follows:
        # | bs0_origin_kv | bs1_origin_kv | bs0_compressed_kv | bs1_compressed_kv |
        # Row i holds max_compressed_kv_len consecutive indices starting at
        # request i's compressed keys, shifted by the total origin-kv length
        # (cu_seq_q_lens[-1]) to skip the origin region. A single broadcast
        # add replaces a per-request loop that needed one .item() sync each.
        column_offsets = torch.arange(
            max_compressed_kv_len, dtype=torch.int32, device=q.device
        )
        kv_cache_block_table = (
            cu_seq_k_lens[:batch_size].unsqueeze(1) + column_offsets
        ) + cu_seq_q_lens[-1]

        # query: (tokens, index_head, index_head_dim)
        # k_full: (tokens, index_head_dim)
        # weights: (tokens, index_head, 1)
        mlu_ops.masked_indexer_select_paged_kv_prefill(
            query=q,
            key_value=k_full,
            weights=weights.unsqueeze(-1),
            kv_cache_block_table=kv_cache_block_table,
            cu_seq_q_lens=cu_seq_q_lens,
            cu_seq_k_lens=cu_seq_k_lens,
            index_topk=self.index_topk,
            kv_cache_block_size=self.block_size,
            softmax_scale=self.merged_softmax_scale,
            q_scale=None,
            k_scale_cache=None,
            sparse_block_table=new_block_tables,
            sparse_context_lens=new_context_lens,
            compress_ratio=self.compress_ratio,
            kv_cache_block_table_offset=None,
        )

        return new_block_tables, new_context_lens

    def forward_decode(
        self,
        q: torch.Tensor,
        x: torch.Tensor,
        k_cache: torch.Tensor,
        weights: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ):
        """Select sparse KV entries for the decode tokens.

        Every decode request is assumed to contribute the same number of
        query tokens, ``x.shape[0] // batch_size``.

        Args:
            q: Indexer queries for the decode tokens,
                (num_decode_tokens, n_heads, head_dim).
            x: Decode hidden states; only ``x.shape[0]`` is used.
            k_cache: Indexer key cache, addressed through
                ``attn_metadata.decode.block_table``.
            weights: Per-head weights, (num_decode_tokens, n_heads).
            attn_metadata: Supplies ``decode.block_table``,
                ``decode.seq_lens`` and ``num_decode_tokens``.

        Returns:
            ``(block_tables, context_lens)``, both int32. ``block_tables`` is
            ``(batch, seq_q, index_topk)`` with dim 1 squeezed away when
            ``seq_q == 1``; ``context_lens`` is ``(num_decode_tokens,)``.
        """
        block_table = attn_metadata.decode.block_table
        device = block_table.device
        batch_size = block_table.shape[0]
        seq_len = x.shape[0] // batch_size
        num_decode_tokens = attn_metadata.num_decode_tokens

        q = q.view(batch_size, seq_len, *q.shape[1:])
        weights = weights.view(batch_size, seq_len, *weights.shape[1:])
        seq_lens = attn_metadata.decode.seq_lens

        new_block_tables = torch.empty(
            [batch_size, seq_len, self.index_topk],
            dtype=torch.int32,
            device=device,
        )
        new_context_lens = torch.empty(
            [num_decode_tokens], dtype=torch.int32, device=device
        )
        # Same offset (``window_size``) for every decode token.
        kv_cache_block_table_offset = torch.full(
            [num_decode_tokens],
            self.window_size,
            dtype=torch.int32,
            device=device,
        )

        mlu_ops.masked_indexer_select_paged_kv_decode(
            query=q,
            k_cache=k_cache,
            weights=weights.unsqueeze(-1),  # (bsz, seq_q, head_num, 1)
            kv_cache_block_table=block_table,
            # Compressed-key length per request.
            k_context_lens=seq_lens // self.compress_ratio,
            k_cache_block_table=block_table,
            index_topk=self.index_topk,
            kv_cache_block_size=self.block_size,
            softmax_scale=self.merged_softmax_scale,
            q_scale=None,
            k_scale_cache=None,
            sparse_block_table=new_block_tables,
            sparse_context_lens=new_context_lens,
            compress_ratio=self.compress_ratio,
            kv_cache_block_table_offset=kv_cache_block_table_offset,
        )

        # [batch, seq_q, index_topk] -> [batch, index_topk] (only when seq_q == 1)
        new_block_tables = new_block_tables.squeeze(1)
        return new_block_tables, new_context_lens

    def forward(self,
                x: torch.Tensor,
                qr: torch.Tensor,
                positions: torch.Tensor,
                offsets: torch.Tensor,
                attn_metadata: AttentionMetadata,
                batch_to_kv_state: torch.Tensor,
                indexer_kv_cache: torch.Tensor,
                compressor_slot_mapping: torch.Tensor,
                ):
        """Compute sparse block tables for every token in the batch.

        Tokens are ordered decode-first: the first
        ``attn_metadata.num_decode_tokens`` rows of ``x``/``qr`` are decode
        tokens and the remaining rows are prefill tokens.

        Args:
            x: Hidden states, one row per token (input to ``weights_proj``
                and the compressor).
            qr: Low-rank query activations, input to ``wq_b``.
            positions: Token positions for the rotary embedding and the
                compressor.
            offsets: Not used by this method.
            attn_metadata: Attention metadata with ``prefill``/``decode``
                sub-metadata and token counts.
            batch_to_kv_state: Forwarded to the compressor.
            indexer_kv_cache: Forwarded to the compressor and used as the key
                cache of the decode kernel.
            compressor_slot_mapping: Forwarded to the compressor.

        Returns:
            ``(block_tables, context_lens)`` (int32). When both phases are
            present, prefill rows come first, then decode rows.
        """
        common_metadata = get_common_metadata()
        query_start_loc = common_metadata.query_start_loc
        query_lens = query_start_loc[1:] - query_start_loc[:-1]

        rd = self.rope_head_dim
        q = self.wq_b(qr)[0]
        q = q.unflatten(-1, (self.n_heads, self.head_dim))

        # The return value is discarded, so the rotary op is relied upon to
        # update the last ``rd`` channels of ``q`` in place.
        self.rotary_emb(positions, q[..., -rd:], None, only_prefill=False)
        q_pack = rotate_activation(q)
        weights_pack = self.weights_proj(x)[0]  # (tokens, n_heads)

        num_decode_tokens = attn_metadata.num_decode_tokens
        compressed_kv = self.compressor(
            x,
            positions,
            attn_metadata,
            batch_to_kv_state,
            indexer_kv_cache,
            0,
            compressor_slot_mapping,
        )

        if attn_metadata.prefill:
            assert compressed_kv is not None and compressed_kv.dim() == 3
            compressed_kv = compressed_kv.squeeze(-2)
            # NOTE(review): ``query_lens`` spans every request in the batch
            # (decode requests included), while ``forward_prefill`` pairs it
            # with ``attn_metadata.prefill.query_start_loc`` (presumably
            # prefill-only). These only line up when the batch has no decode
            # requests — confirm for mixed batches.
            compressed_context_lens = query_lens // self.compress_ratio

            prefill_block_tables, prefill_context_lens = self.forward_prefill(
                q_pack[num_decode_tokens:],
                indexer_kv_cache,
                weights_pack[num_decode_tokens:],
                attn_metadata,
                compressed_kv,
                compressed_context_lens,
            )

        if attn_metadata.decode:
            # Decode tokens are the first ``num_decode_tokens`` rows; q, x and
            # the weights must all be sliced the same way.
            decode_block_tables, decode_context_lens = self.forward_decode(
                q_pack[:num_decode_tokens],
                x[:num_decode_tokens],
                indexer_kv_cache,
                weights_pack[:num_decode_tokens],
                attn_metadata,
            )

        # NOTE(review): tokens are ordered decode-first above, but the
        # combined tables put prefill rows first — confirm this matches what
        # the consumer of these tables expects.
        if attn_metadata.prefill and attn_metadata.decode:
            new_block_tables = torch.cat(
                [prefill_block_tables, decode_block_tables], dim=0
            )
            new_context_lens = torch.cat(
                [prefill_context_lens, decode_context_lens], dim=0
            )
        elif attn_metadata.prefill:
            new_block_tables = prefill_block_tables
            new_context_lens = prefill_context_lens
        else:
            new_block_tables = decode_block_tables
            new_context_lens = decode_context_lens
        return new_block_tables, new_context_lens
|