[Model] Support DeepSeek-V4
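
Adds an Indexer layer that selects the top index_topk KV-cache blocks per
query token for sparse attention: the low-rank query is projected through
wq_b, per-head token weights come from weights_proj, the KV stream is
compressed by compress_ratio via the existing Compressor, and the MLU
masked_indexer_select_paged_kv_{prefill,decode} kernels produce the sparse
block tables for the prefill and decode paths.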

vllm_mlu/model_executor/layers/indexer.py (new file, 278 lines)
@@ -0,0 +1,278 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project

import torch

from vllm.attention import AttentionMetadata
from vllm.config import VllmConfig
from vllm.distributed import (
    get_tensor_model_parallel_world_size,
    get_tp_group,
)
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import ReplicatedLinear

from vllm_mlu import _mlu_ops as mlu_ops
from vllm_mlu.model_executor.layers.compressor import (
    Compressor,
    rotate_activation,
)
from vllm_mlu.v1.attention.backends.utils import get_common_metadata

logger = init_logger(__name__)

class Indexer(torch.nn.Module):
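    """Selects the top ``index_topk`` KV-cache blocks per query token.

    Scores packed indexer queries against a compressed KV stream
    (``compress_ratio`` tokens per entry) and writes per-token sparse
    block tables and context lengths for the attention backend.
    """
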
    def __init__(
        self,
        vllm_config: VllmConfig,
        rope,
        compress_ratio: int = 4,
        prefix: str = "",
        **kwargs,
    ):
        super().__init__()
        config = vllm_config.model_config.hf_config
        self.dim = config.dim
        self.n_heads = config.index_n_heads
        self.tp_size = get_tensor_model_parallel_world_size()
        self.n_local_heads = config.index_n_heads // self.tp_size
        self.head_dim = config.index_head_dim
        self.rope_head_dim = config.rope_head_dim
        self.index_topk = config.index_topk
        self.q_lora_rank = config.q_lora_rank
        self.window_size = config.window_size
        self.block_size = vllm_config.cache_config.block_size

        self.wq_b = ReplicatedLinear(
            self.q_lora_rank,
            self.n_heads * self.head_dim,
            bias=False,
            quant_config=None,
            prefix=f"{prefix}.wq_b",
        )

        self.weights_proj = ReplicatedLinear(
            self.dim,
            self.n_heads,
            bias=False,
            quant_config=None,
            params_dtype=torch.bfloat16,
            prefix=f"{prefix}.weights_proj",
        )

        self.softmax_scale = self.head_dim ** -0.5
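        # merged_softmax_scale folds the usual 1/sqrt(head_dim) attention
        # scale with an extra 1/sqrt(n_heads) factor.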
        self.merged_softmax_scale = (self.head_dim ** -0.5) * (self.n_heads ** -0.5)
        self.compress_ratio = compress_ratio
        self.max_model_len = vllm_config.model_config.max_model_len

        self.rotary_emb = rope
        self.tp_group = get_tp_group()

        self.compressor = Compressor(
            vllm_config,
            self.rotary_emb,
            compress_ratio,
            self.head_dim,
            True,
            f"{prefix}.compressor",
        )

        self.freqs_cis = None

    def forward_prefill(
        self,
        q: torch.Tensor,
        k_cache: torch.Tensor,
        weights: torch.Tensor,
        attn_metadata: AttentionMetadata,
        k_full: torch.Tensor,
        context_lens: torch.Tensor,
    ):
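        # q:            (tokens, n_heads, head_dim) packed indexer queries
        # k_full:       compressed keys for the prefill tokens
        # context_lens: per-request compressed KV lengths
        # Returns per-token (block_tables, context_lens) covering only the
        # selected top-index_topk positions.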
        assert attn_metadata.prefill.chunked_context is None, \
            "Prefill chunked context is not supported."

        query_start_loc = attn_metadata.prefill.query_start_loc
        cu_seq_q_lens = query_start_loc
        cu_seq_k_lens = torch.zeros(
            context_lens.size(0) + 1, dtype=torch.int32, device=q.device,
        )
        torch.cumsum(context_lens, dim=0, out=cu_seq_k_lens[1:])

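        # new_block_tables / new_context_lens are output buffers: the kernel
        # below fills them in-place via sparse_block_table / sparse_context_lens.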
        new_block_tables = torch.empty(
            [attn_metadata.num_prefill_tokens, self.index_topk],
            dtype=torch.int32,
            device=q.device,
        )
        new_context_lens = torch.empty(
            [attn_metadata.num_prefill_tokens],
            dtype=torch.int32,
            device=q.device,
        )

        q_seq_lens = cu_seq_q_lens[1:] - cu_seq_q_lens[:-1]
        max_seq_len = q_seq_lens.max().item()
        batch_size = q_seq_lens.size(0)
        max_compressed_kv_len = max_seq_len // self.compress_ratio
        kv_cache_block_table = torch.zeros(
            [batch_size, max_compressed_kv_len],
            dtype=torch.int32,
            device=q.device,
        )

        # Layout of the linear KV buffer:
        # | bs0_origin_kv | bs1_origin_kv | bs0_compressed_kv | bs1_compressed_kv |
        for i in range(batch_size):
            start = cu_seq_k_lens[i].item()
            kv_cache_block_table[i] = torch.arange(
                start, start + max_compressed_kv_len,
                dtype=torch.int32,
                device=q.device,
            )
        # Offset past the total original (uncompressed) KV length.
        kv_cache_block_table = kv_cache_block_table + cu_seq_q_lens[-1]

        # query:   (tokens, index_head, index_head_dim)
        # k_full:  (tokens, index_head_dim)
        # weights: (tokens, index_head, 1)
        mlu_ops.masked_indexer_select_paged_kv_prefill(
            query=q,
            key_value=k_full,
            weights=weights.unsqueeze(-1),
            kv_cache_block_table=kv_cache_block_table,
            cu_seq_q_lens=cu_seq_q_lens,
            cu_seq_k_lens=cu_seq_k_lens,
            index_topk=self.index_topk,
            kv_cache_block_size=self.block_size,
            softmax_scale=self.merged_softmax_scale,
            q_scale=None,
            k_scale_cache=None,
            sparse_block_table=new_block_tables,
            sparse_context_lens=new_context_lens,
            compress_ratio=self.compress_ratio,
            kv_cache_block_table_offset=None,
        )

        return new_block_tables, new_context_lens

    def forward_decode(
        self,
        q: torch.Tensor,
        x: torch.Tensor,
        k_cache: torch.Tensor,
        weights: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ):
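        # Reshape the flat decode tokens to (batch, seq_q, ...), as expected
        # by the decode kernel.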
        block_table = attn_metadata.decode.block_table
        batch_size = block_table.shape[0]
        seq_len = x.shape[0] // batch_size
        q = q.view(batch_size, seq_len, *q.shape[1:])
        weights = weights.view(batch_size, seq_len, *weights.shape[1:])

        seq_lens = attn_metadata.decode.seq_lens
        k_block_table = block_table

        new_block_tables = torch.empty(
            [batch_size, seq_len, self.index_topk],
            dtype=torch.int32,
            device=block_table.device,
        )
        new_context_lens = torch.empty(
            [attn_metadata.num_decode_tokens],
            dtype=torch.int32,
            device=block_table.device,
        )

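        # Per-token block-table offset, filled with window_size below; the
        # exact offset semantics are defined by the MLU decode kernel.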
        kv_cache_block_table_offset = torch.empty(
            [attn_metadata.num_decode_tokens],
            dtype=torch.int32,
            device=block_table.device,
        )
        kv_cache_block_table_offset.fill_(self.window_size)
        mlu_ops.masked_indexer_select_paged_kv_decode(
            query=q,
            k_cache=k_cache,
            weights=weights.unsqueeze(-1),  # (bsz, seq_q, head_num, 1)
            kv_cache_block_table=block_table,
            k_context_lens=seq_lens // self.compress_ratio,
            k_cache_block_table=k_block_table,
            index_topk=self.index_topk,
            kv_cache_block_size=self.block_size,
            softmax_scale=self.merged_softmax_scale,
            q_scale=None,
            k_scale_cache=None,
            sparse_block_table=new_block_tables,
            sparse_context_lens=new_context_lens,
            compress_ratio=self.compress_ratio,
            kv_cache_block_table_offset=kv_cache_block_table_offset,
        )

        # [batch, seq_q, index_topk] -> [batch, index_topk]
        new_block_tables = new_block_tables.squeeze(1)
        return new_block_tables, new_context_lens

    def forward(
        self,
        x: torch.Tensor,
        qr: torch.Tensor,
        positions: torch.Tensor,
        offsets: torch.Tensor,
        attn_metadata: AttentionMetadata,
        batch_to_kv_state: torch.Tensor,
        indexer_kv_cache: torch.Tensor,
        compressor_slot_mapping: torch.Tensor,
    ):
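        # Token layout: decode tokens occupy x[:num_decode_tokens] and
        # prefill tokens follow; each path returns sparse block tables and
        # context lengths, which are concatenated at the end.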
        common_metadata = get_common_metadata()
        query_start_loc = common_metadata.query_start_loc
        query_lens = query_start_loc[1:] - query_start_loc[:-1]

        rd = self.rope_head_dim
        q = self.wq_b(qr)[0]
        q = q.unflatten(-1, (self.n_heads, self.head_dim))

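        # RoPE is applied in-place to the last rope_head_dim dims of each
        # query head; rotate_activation then produces the packed query
        # (q_pack) consumed by the indexer kernels.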
        self.rotary_emb(positions, q[..., -rd:], None, only_prefill=False)
        q_pack = rotate_activation(q)
        weights_pack = self.weights_proj(x)[0]  # (tokens, index_local_head)

        num_decode_tokens = attn_metadata.num_decode_tokens
        compressed_kv = self.compressor(
            x,
            positions,
            attn_metadata,
            batch_to_kv_state,
            indexer_kv_cache,
            0,
            compressor_slot_mapping,
        )
        if attn_metadata.prefill:
            assert compressed_kv is not None and compressed_kv.dim() == 3
            compressed_kv = compressed_kv.squeeze(-2)
            compressed_context_lens = query_lens // self.compress_ratio

            prefill_q = q_pack[num_decode_tokens:, ...]
            prefill_weights = weights_pack[num_decode_tokens:, ...]
            prefill_block_tables, prefill_context_lens = self.forward_prefill(
                prefill_q,
                indexer_kv_cache,
                prefill_weights,
                attn_metadata,
                compressed_kv,
                compressed_context_lens,
            )

        if attn_metadata.decode:
            decode_x = x[:num_decode_tokens, ...]
            decode_q = q_pack[:num_decode_tokens, ...]
            decode_weights = weights_pack[:num_decode_tokens, ...]
            decode_block_tables, decode_context_lens = self.forward_decode(
                decode_q,
                decode_x,
                indexer_kv_cache,
                decode_weights,
                attn_metadata,
            )

        if attn_metadata.prefill and attn_metadata.decode:
            new_block_tables = torch.cat(
                [prefill_block_tables, decode_block_tables], dim=0)
            new_context_lens = torch.cat(
                [prefill_context_lens, decode_context_lens], dim=0)
        elif attn_metadata.prefill:
            new_block_tables = prefill_block_tables
            new_context_lens = prefill_context_lens
        else:
            new_block_tables = decode_block_tables
            new_context_lens = decode_context_lens
        return new_block_tables, new_context_lens