# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
import math
from typing import Callable
from scipy.linalg import hadamard
import torch
from torch import nn
import torch.nn.functional as F
from vllm.attention import AttentionMetadata
from vllm.config import VllmConfig
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import ReplicatedLinear
from vllm_mlu import _mlu_ops as mlu_ops
from vllm_mlu.v1.attention.backends.utils import get_common_metadata
def hadamard_transform_ref(x, scale=1.0):
"""
x: (..., dim)
out: (..., dim)
"""
x_shape = x.shape
dim = x.shape[-1]
x = x.reshape(-1, dim)
log_dim = math.ceil(math.log2(dim))
dim_padded = 2 ** log_dim
if dim != dim_padded:
x = F.pad(x, (0, dim_padded - dim))
out = F.linear(
x,
torch.tensor(hadamard(dim_padded, dtype=float), dtype=x.dtype, device=x.device),
)
out = out * scale
return out[..., :dim].reshape(*x_shape)
def rotate_activation(x: torch.Tensor) -> torch.Tensor:
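    """Apply a scaled Hadamard rotation along the hidden dimension.

    With scale = hidden_size ** -0.5 the transform is orthonormal
    (H @ H.T == n * I for an n x n Hadamard matrix H), so activation
    norms are preserved.
    """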
assert x.dtype == torch.bfloat16
hidden_size = x.size(-1)
return hadamard_transform_ref(x, scale=hidden_size ** -0.5)
class Compressor(nn.Module):
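    """Compresses every `compress_ratio` input tokens into a single KV entry.

    `wkv` projects hidden states to candidate KV vectors and `wgate` to
    per-token scores; fused MLU kernels reduce each complete window into one
    compressed KV vector, while per-sequence running state (kv_state /
    score_state) carries windows that are not yet complete. When
    `compress_ratio == 4`, an additional overlapping compression path is
    used (see `overlap_transform`). Summary inferred from the fused-kernel
    call sites below.
    """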
def __init__(self,
vllm_config: VllmConfig,
rope,
compress_ratio: int = 4,
head_dim: int = 512,
rotate: bool = False,
prefix: str = "",
**kwargs,):
super().__init__()
config = vllm_config.model_config.hf_config
self.dim = config.dim
self.head_dim = head_dim
        self.rope_head_dim = config.rope_head_dim
self.nope_head_dim = head_dim - config.rope_head_dim
self.compress_ratio = compress_ratio
self.overlap = compress_ratio == 4
self.rotate = rotate
coff = 1 + self.overlap
self.norm_eps = config.norm_eps
self.window_size = config.window_size
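        # Learned embedding with one row per position inside a compression
        # window ("ape" presumably abbreviates absolute position embedding).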
self.ape = nn.Parameter(torch.empty(compress_ratio, coff * self.head_dim, dtype=torch.float32))
        # wkv and wgate are stored in bf16 in the checkpoint, while the parameters here are kept in fp32 for convenience.
        # The first half of the output dimensions is used for overlapping compression, the second half for normal compression.
self.wkv = ReplicatedLinear(
self.dim,
coff * self.head_dim,
bias=False,
quant_config=None,
            params_dtype=torch.float32,
prefix=f"{prefix}.wkv",
)
self.wgate = ReplicatedLinear(
self.dim,
coff * self.head_dim,
bias=False,
quant_config=None,
            params_dtype=torch.float32,
prefix=f"{prefix}.wgate",
)
self.norm = RMSNorm(self.head_dim, self.norm_eps)
self.rotary_emb = rope
        assert hasattr(config, "cached_state_num"), \
            "cached_state_num is not set in hf_config"
        cached_state_num = config.cached_state_num
self.register_buffer(
"kv_state",
torch.zeros(cached_state_num, coff * compress_ratio, coff * self.head_dim, dtype=torch.float32),
persistent=False,
)
self.register_buffer(
"score_state",
torch.full(
(cached_state_num, coff * compress_ratio, coff * self.head_dim),
float("-inf"),
dtype=torch.float32,
),
persistent=False,
)
self.hadamard_matrix = torch.tensor(
hadamard(self.head_dim, dtype=float), dtype=torch.get_default_dtype(), device="mlu")
def overlap_transform(self, tensor: torch.Tensor, value=0):
# tensor: [b,s,r,2d]
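        # Pair each window's "normal" half (last d features) with the previous
        # window's "overlap" half (first d features), shifting the overlap half
        # forward by one window; the first window's overlap slots are filled
        # with `value`.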
b, s, _, _ = tensor.size()
ratio, d = self.compress_ratio, self.head_dim
new_tensor = tensor.new_full((b, s, 2 * ratio, d), value)
new_tensor[:, :, ratio:] = tensor[:, :, :, d:]
new_tensor[:, 1:, :ratio] = tensor[:, :-1, :, :d]
return new_tensor
def forward_decode(
self,
x: torch.Tensor,
positions: torch.Tensor,
attn_metadata: AttentionMetadata,
batch_to_kv_state: torch.Tensor,
kv_cache: torch.Tensor,
window_offset: int,
compressor_slot_mapping: torch.Tensor,
):
x = x.float()
kv_pack, _ = self.wkv(x)
score_pack, _ = self.wgate(x)
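        # Single-token decode path: the fused kernel updates the per-sequence
        # compression state (kv_state / score_state) and writes completed
        # compressed entries into kv_cache via `compressor_slot_mapping`.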
mlu_ops.fused_compress_single_kv(
kv=kv_pack.unsqueeze(1), # (token, D) -> (B, S, D)
score=score_pack.unsqueeze(1), # (token, D) -> (B, S, D)
position=positions,
ape=self.ape,
kv_state=self.kv_state,
score_state=self.score_state,
gamma=self.norm.weight,
sin=self.rotary_emb.sin_,
cos=self.rotary_emb.cos_,
hadamard_matrix=self.hadamard_matrix,
slot_mapping=compressor_slot_mapping,
kv_cache=kv_cache,
kv_cache_scale=None,
eps=self.norm_eps,
overlap=self.overlap,
rotate=self.rotate,
state_idx=batch_to_kv_state,
)
        # The fused kernel writes the compressed results into kv_cache in place,
        # so return a placeholder compressed_kv here.
return None
def forward(
self,
x: torch.Tensor,
positions: torch.Tensor,
attn_metadata: AttentionMetadata,
batch_to_kv_state: torch.Tensor,
kv_cache: torch.Tensor,
window_offset: int,
compressor_slot_mapping: torch.Tensor,
):
common_metadata = get_common_metadata()
forward_func: Callable = (
self.forward_prefill if common_metadata.is_prefill_only
else self.forward_decode
)
return forward_func(
x,
positions,
attn_metadata,
batch_to_kv_state,
kv_cache,
window_offset,
compressor_slot_mapping,
)
def forward_prefill(
self,
x: torch.Tensor,
positions: torch.Tensor,
attn_metadata: AttentionMetadata,
batch_to_kv_state: torch.Tensor,
kv_cache: torch.Tensor,
window_offset: int,
compressor_slot_mapping: torch.Tensor,
):
common_metadata = get_common_metadata()
seq_lens = common_metadata.seq_lens
query_start_loc = common_metadata.query_start_loc
query_lens = query_start_loc[1:] - query_start_loc[:-1]
ratio, overlap = self.compress_ratio, self.overlap
dtype = x.dtype
x = x.float()
kv_pack, _ = self.wkv(x)
score_pack, _ = self.wgate(x)
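        # Each sequence contributes floor(query_len / compress_ratio) compressed
        # tokens; trailing tokens that do not fill a complete window are not
        # emitted here.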
compress_lens = query_lens // self.compress_ratio
cu_compress_lens = torch.cat([
torch.tensor([0], dtype=compress_lens.dtype, device=compress_lens.device),
torch.cumsum(compress_lens, dim=0)],
)
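        # For RoPE, each compressed token reuses the position of the first
        # token in its window.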
compress_positions = []
for i in range(len(seq_lens)):
seqlen = (query_start_loc[i+1] - query_start_loc[i]).item()
remainder = seqlen % ratio
cutoff = seqlen - remainder
pos = positions[query_start_loc[i]: query_start_loc[i+1]]
positions_ = pos[:cutoff:ratio].contiguous()
compress_positions.append(positions_)
kv_positions = torch.cat(compress_positions, dim=0)
total_compress_len = cu_compress_lens[-1].item()
kv = torch.empty(
[total_compress_len, self.head_dim],
dtype=kv_pack.dtype,
device=kv_pack.device,
)
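        # Fused prefill kernel: reduces every complete window into one row of
        # `kv` and leaves the trailing partial-window state in
        # kv_state / score_state for subsequent decode steps.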
mlu_ops.fused_compress_multi_kv(
            kv=kv_pack,
            score=score_pack,
            kv_state=self.kv_state,
            score_state=self.score_state,
            state_batch_idx=batch_to_kv_state,
            cu_seqlens=query_start_loc,
            ape=self.ape,
            max_seqlen=common_metadata.max_query_len,
            overlap=overlap,
            compressed_kv=kv,
)
if kv.size(0) == 0:
return kv.unsqueeze(-2).to(dtype) # (compress_token_num, 1, head_size)
kv = self.norm(kv.to(dtype))
kv_rope = kv[..., -self.rope_head_dim:].unsqueeze(-2)
        # Compressed cu_seqlens are used here, so rotary_emb cannot be called directly.
kv_rope = mlu_ops.rotary_embedding(
kv_rope,
self.rotary_emb.sin_,
self.rotary_emb.cos_,
kv_positions,
torch.tensor([0, kv_positions.size(0)], dtype=torch.int32, device=kv_positions.device), # cu_seqlens
True, # interleaved
True, # discrete
False,
common_metadata.max_query_len,
)
if self.rotate:
kv = rotate_activation(kv)
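        # Scatter the compressed KV vectors into the paged cache at the
        # compressor's slot mapping.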
mlu_ops.reshape_paged_cache(
kv.unsqueeze(1),
None,
kv_cache,
None,
compressor_slot_mapping,
)
return kv.unsqueeze(-2) # (compress_token_num, 1, head_size)