[Model] Support DeepSeek-V4
This commit is contained in:
277
vllm_mlu/model_executor/layers/compressor.py
Normal file
277
vllm_mlu/model_executor/layers/compressor.py
Normal file
@@ -0,0 +1,277 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
||||
|
||||
import math
|
||||
from typing import Callable
|
||||
from scipy.linalg import hadamard
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from vllm.attention import AttentionMetadata
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.linear import ReplicatedLinear
|
||||
|
||||
from vllm_mlu import _mlu_ops as mlu_ops
|
||||
from vllm_mlu.v1.attention.backends.utils import get_common_metadata
|
||||
|
||||
|
||||
def hadamard_transform_ref(x, scale=1.0):
    """Reference Walsh-Hadamard transform over the last dimension.

    x: (..., dim)
    out: (..., dim)

    The input is zero-padded up to the next power of two, multiplied by
    the Hadamard matrix of that size, scaled, and truncated back to
    ``dim`` before restoring the original shape.
    """
    orig_shape = x.shape
    dim = orig_shape[-1]
    flat = x.reshape(-1, dim)
    # Hadamard matrices only exist for power-of-two sizes here.
    padded_dim = 2 ** math.ceil(math.log2(dim))
    if padded_dim != dim:
        flat = F.pad(flat, (0, padded_dim - dim))
    weight = torch.tensor(
        hadamard(padded_dim, dtype=float), dtype=flat.dtype, device=flat.device
    )
    transformed = scale * F.linear(flat, weight)
    return transformed[..., :dim].reshape(*orig_shape)
|
||||
|
||||
def rotate_activation(x: torch.Tensor) -> torch.Tensor:
    """Apply a 1/sqrt(dim)-scaled Hadamard rotation to a bf16 activation."""
    assert x.dtype == torch.bfloat16
    dim = x.size(-1)
    # Scaling by dim ** -0.5 keeps the Hadamard transform orthonormal.
    return hadamard_transform_ref(x, scale=dim ** -0.5)
|
||||
|
||||
|
||||
class Compressor(nn.Module):
    """Compresses runs of `compress_ratio` tokens into single KV entries.

    Compressed KV is produced by fp32 linear projections (`wkv` for
    values, `wgate` for gate scores), a learned additive positional
    embedding (`ape`), RMS normalization, rotary embedding on the
    trailing `rope_head_dim` channels, and an optional Hadamard
    rotation.  Non-persistent running buffers (`kv_state` /
    `score_state`) carry partially-filled compression windows across
    decode steps.

    NOTE(review): the fused MLU kernels called below are opaque from
    this file; comments describing their behavior are inferred and
    should be confirmed against the kernel sources.
    """

    def __init__(self,
                 vllm_config: VllmConfig,
                 rope,
                 compress_ratio: int = 4,
                 head_dim: int = 512,
                 rotate: bool = False,
                 prefix: str = "",
                 **kwargs,):
        """Build the compressor.

        Args:
            vllm_config: global vLLM config; its HF model config supplies
                `dim`, `rope_head_dim`, `norm_eps`, `window_size` and
                `cached_state_num`.
            rope: rotary-embedding module exposing `sin_` / `cos_` tables.
            compress_ratio: number of tokens folded into one compressed entry.
            head_dim: size of one compressed KV head.
            rotate: whether to Hadamard-rotate the compressed KV.
            prefix: parameter-name prefix used for checkpoint loading.
        """
        super().__init__()
        config = vllm_config.model_config.hf_config
        self.dim = config.dim
        self.head_dim = head_dim
        self.rope_head_dim =config.rope_head_dim
        self.nope_head_dim = head_dim - config.rope_head_dim
        self.compress_ratio = compress_ratio
        # Overlapping compression is enabled only for ratio 4.
        self.overlap = compress_ratio == 4
        self.rotate = rotate
        # coff is 2 when overlapping (extra half for the overlapped window),
        # otherwise 1 (the bool promotes to int here).
        coff = 1 + self.overlap
        self.norm_eps = config.norm_eps
        self.window_size = config.window_size

        # Learned additive positional embedding: one row per intra-window
        # position.
        self.ape = nn.Parameter(torch.empty(compress_ratio, coff * self.head_dim, dtype=torch.float32))
        # wkv and wgate in the checkpoint is stored in bf16, while the parameter here is stored in fp32 for convenient.
        # The first half of dimensions for overlapping compression and second half for normal compression.

        self.wkv = ReplicatedLinear(
            self.dim,
            coff * self.head_dim,
            bias=False,
            quant_config=None,
            params_dtype = torch.float32,
            prefix=f"{prefix}.wkv",
        )

        self.wgate = ReplicatedLinear(
            self.dim,
            coff * self.head_dim,
            bias=False,
            quant_config=None,
            params_dtype = torch.float32,
            prefix=f"{prefix}.wgate",
        )

        self.norm = RMSNorm(self.head_dim, self.norm_eps)

        self.rotary_emb = rope

        # NOTE(review): `hf_config` duplicates `config` above.
        hf_config = vllm_config.model_config.hf_config
        assert hasattr(hf_config, "cached_state_num"), \
            f"cached_state_num is not set in hf_config"
        cached_state_num = hf_config.cached_state_num
        # Running KV state for partially-filled windows, one slot per cached
        # sequence; not saved with the checkpoint.
        self.register_buffer(
            "kv_state",
            torch.zeros(cached_state_num, coff * compress_ratio, coff * self.head_dim, dtype=torch.float32),
            persistent=False,
        )
        # Gate-score state accumulated alongside kv_state; -inf marks
        # positions not yet written.
        self.register_buffer(
            "score_state",
            torch.full(
                (cached_state_num, coff * compress_ratio, coff * self.head_dim),
                float("-inf"),
                dtype=torch.float32,
            ),
            persistent=False,
        )

        # Precomputed Hadamard matrix handed to the fused decode kernel.
        self.hadamard_matrix = torch.tensor(
            hadamard(self.head_dim, dtype=float), dtype=torch.get_default_dtype(), device="mlu")

    def overlap_transform(self, tensor: torch.Tensor, value=0):
        # tensor: [b,s,r,2d]
        # Splits each window's 2d channels into an overlap half (first d)
        # and a normal half (last d), producing [b, s, 2r, d]: slots
        # [ratio:] hold the current step's normal half, slots [:ratio] hold
        # the PREVIOUS step's overlap half (shifted by one along s); the
        # first step's overlap slots keep the fill `value`.
        b, s, _, _ = tensor.size()
        ratio, d = self.compress_ratio, self.head_dim
        new_tensor = tensor.new_full((b, s, 2 * ratio, d), value)
        new_tensor[:, :, ratio:] = tensor[:, :, :, d:]
        new_tensor[:, 1:, :ratio] = tensor[:, :-1, :, :d]
        return new_tensor

    def forward_decode(
        self,
        x: torch.Tensor,
        positions: torch.Tensor,
        attn_metadata: AttentionMetadata,
        batch_to_kv_state: torch.Tensor,
        kv_cache: torch.Tensor,
        window_offset: int,
        compressor_slot_mapping: torch.Tensor,
    ):
        """Single-token (decode) path.

        Projects the new token in fp32 and delegates to the fused MLU
        kernel, which (presumably — confirm against kernel source)
        updates the running kv/score state in place and writes any
        completed compressed entry into the paged KV cache.

        Returns None: decode consumers read the compressed KV from the
        cache rather than from this return value.
        """
        x = x.float()
        kv_pack, _ = self.wkv(x)
        score_pack, _ = self.wgate(x)

        mlu_ops.fused_compress_single_kv(
            kv=kv_pack.unsqueeze(1), # (token, D) -> (B, S, D)
            score=score_pack.unsqueeze(1), # (token, D) -> (B, S, D)
            position=positions,
            ape=self.ape,
            kv_state=self.kv_state,
            score_state=self.score_state,
            gamma=self.norm.weight,
            sin=self.rotary_emb.sin_,
            cos=self.rotary_emb.cos_,
            hadamard_matrix=self.hadamard_matrix,
            slot_mapping=compressor_slot_mapping,
            kv_cache=kv_cache,
            kv_cache_scale=None,
            eps=self.norm_eps,
            overlap=self.overlap,
            rotate=self.rotate,
            state_idx=batch_to_kv_state,
        )

        # Here, return fake compressed_kv.
        return None

    def forward(
        self,
        x: torch.Tensor,
        positions: torch.Tensor,
        attn_metadata: AttentionMetadata,
        batch_to_kv_state: torch.Tensor,
        kv_cache: torch.Tensor,
        window_offset: int,
        compressor_slot_mapping: torch.Tensor,
    ):
        """Dispatch to the prefill or decode path based on the current
        batch's common metadata."""
        common_metadata = get_common_metadata()
        forward_func: Callable = (
            self.forward_prefill if common_metadata.is_prefill_only
            else self.forward_decode
        )
        return forward_func(
            x,
            positions,
            attn_metadata,
            batch_to_kv_state,
            kv_cache,
            window_offset,
            compressor_slot_mapping,
        )

    def forward_prefill(
        self,
        x: torch.Tensor,
        positions: torch.Tensor,
        attn_metadata: AttentionMetadata,
        batch_to_kv_state: torch.Tensor,
        kv_cache: torch.Tensor,
        window_offset: int,
        compressor_slot_mapping: torch.Tensor,
    ):
        """Prefill path: compress every completed `compress_ratio`-token
        window of each sequence.

        Returns the compressed KV as (compress_token_num, 1, head_dim),
        after also writing it into the paged KV cache.
        """
        common_metadata = get_common_metadata()
        seq_lens = common_metadata.seq_lens
        query_start_loc = common_metadata.query_start_loc
        query_lens = query_start_loc[1:] - query_start_loc[:-1]

        ratio, overlap = self.compress_ratio, self.overlap
        # Projections run in fp32 (see __init__); remember the activation
        # dtype to cast back after compression.
        dtype = x.dtype
        x = x.float()
        kv_pack, _ = self.wkv(x)
        score_pack, _ = self.wgate(x)

        # One compressed entry per completed window; trailing remainder
        # tokens are left to the running state.
        compress_lens = query_lens // self.compress_ratio
        cu_compress_lens = torch.cat([
            torch.tensor([0], dtype=compress_lens.dtype, device=compress_lens.device),
            torch.cumsum(compress_lens, dim=0)],
        )

        # Each compressed entry takes the position of the first token of
        # its window (stride `ratio` over the full windows only).
        compress_positions = []
        for i in range(len(seq_lens)):
            seqlen = (query_start_loc[i+1] - query_start_loc[i]).item()
            remainder = seqlen % ratio
            cutoff = seqlen - remainder
            pos = positions[query_start_loc[i]: query_start_loc[i+1]]
            positions_ = pos[:cutoff:ratio].contiguous()
            compress_positions.append(positions_)
        kv_positions = torch.cat(compress_positions, dim=0)


        total_compress_len = cu_compress_lens[-1].item()
        # Output buffer filled in place by the fused kernel below.
        kv = torch.empty(
            [total_compress_len, self.head_dim],
            dtype=kv_pack.dtype,
            device=kv_pack.device,
        )

        mlu_ops.fused_compress_multi_kv(
            kv = kv_pack,
            score = score_pack,
            kv_state = self.kv_state,
            score_state = self.score_state,
            state_batch_idx = batch_to_kv_state,
            cu_seqlens = query_start_loc,
            ape = self.ape,
            max_seqlen = common_metadata.max_query_len,
            overlap = overlap,
            compressed_kv = kv,
        )

        if kv.size(0) == 0:
            # No window completed anywhere in the batch; nothing to
            # normalize or cache.
            return kv.unsqueeze(-2).to(dtype) # (compress_token_num, 1, head_size)


        kv = self.norm(kv.to(dtype))

        # Rotary embedding applies only to the trailing rope channels.
        kv_rope = kv[..., -self.rope_head_dim:].unsqueeze(-2)
        # use compressed cu_seqlens here, so can not call rotary_emb directly
        # NOTE(review): the assigned result is never read afterwards; this
        # presumably mutates `kv` in place through the `kv_rope` view —
        # confirm against the rotary_embedding kernel.
        kv_rope = mlu_ops.rotary_embedding(
            kv_rope,
            self.rotary_emb.sin_,
            self.rotary_emb.cos_,
            kv_positions,
            torch.tensor([0, kv_positions.size(0)], dtype=torch.int32, device=kv_positions.device), # cu_seqlens
            True, # interleaved
            True, # discrete
            False,
            common_metadata.max_query_len,
        )

        if self.rotate:
            kv = rotate_activation(kv)

        # Persist the compressed entries into the paged KV cache.
        mlu_ops.reshape_paged_cache(
            kv.unsqueeze(1),
            None,
            kv_cache,
            None,
            compressor_slot_mapping,
        )

        return kv.unsqueeze(-2) # (compress_token_num, 1, head_size)
|
||||
Reference in New Issue
Block a user