277 lines
8.9 KiB
Python
277 lines
8.9 KiB
Python
|
|
# SPDX-License-Identifier: Apache-2.0
|
||
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
||
|
|
|
||
|
|
import math
|
||
|
|
from typing import Callable
|
||
|
|
from scipy.linalg import hadamard
|
||
|
|
|
||
|
|
import torch
|
||
|
|
from torch import nn
|
||
|
|
import torch.nn.functional as F
|
||
|
|
|
||
|
|
from vllm.attention import AttentionMetadata
|
||
|
|
from vllm.config import VllmConfig
|
||
|
|
from vllm.model_executor.layers.layernorm import RMSNorm
|
||
|
|
from vllm.model_executor.layers.linear import ReplicatedLinear
|
||
|
|
|
||
|
|
from vllm_mlu import _mlu_ops as mlu_ops
|
||
|
|
from vllm_mlu.v1.attention.backends.utils import get_common_metadata
|
||
|
|
|
||
|
|
|
||
|
|
def hadamard_transform_ref(x, scale=1.0):
    """Reference Walsh-Hadamard transform over the last dimension.

    x: (..., dim)
    out: (..., dim)

    Inputs whose last dimension is not a power of two are zero-padded up to
    the next power of two before the transform, and the result is truncated
    back to ``dim`` afterwards.
    """
    orig_shape = x.shape
    dim = orig_shape[-1]
    flat = x.reshape(-1, dim)

    # Round dim up to the next power of two; hadamard() requires that.
    padded_dim = 2 ** math.ceil(math.log2(dim))
    if padded_dim != dim:
        flat = F.pad(flat, (0, padded_dim - dim))

    # Sylvester Hadamard matrices are symmetric, so F.linear (x @ W^T) is
    # equivalent to x @ H here.
    h = torch.tensor(
        hadamard(padded_dim, dtype=float), dtype=flat.dtype, device=flat.device
    )
    transformed = F.linear(flat, h) * scale
    return transformed[..., :dim].reshape(*orig_shape)
|
||
|
|
|
||
|
|
def rotate_activation(x: torch.Tensor) -> torch.Tensor:
    """Apply a 1/sqrt(dim)-scaled Hadamard rotation to a bf16 activation."""
    assert x.dtype == torch.bfloat16
    dim = x.size(-1)
    # Scaling by dim ** -0.5 makes the Hadamard transform orthonormal.
    return hadamard_transform_ref(x, scale=dim ** -0.5)
|
||
|
|
|
||
|
|
|
||
|
|
class Compressor(nn.Module):
    """Compresses runs of per-token KV activations into fewer KV slots.

    Every ``compress_ratio`` input tokens are folded into one compressed KV
    vector via learned projections (``wkv`` for values, ``wgate`` for scores),
    an additive position embedding (``ape``), RMS normalization, rotary
    embedding on the rope sub-dimensions, and an optional Hadamard rotation.
    When ``compress_ratio == 4`` an overlapping compression scheme is enabled
    and the projection width doubles (``coff == 2``); per the checkpoint
    layout, the first half of the projected dimensions serves the overlapping
    compression and the second half the normal one.
    """

    def __init__(self,
                 vllm_config: VllmConfig,
                 rope,
                 compress_ratio: int = 4,
                 head_dim: int = 512,
                 rotate: bool = False,
                 prefix: str = "",
                 **kwargs,):
        """Build the compressor.

        Args:
            vllm_config: vLLM config; ``model_config.hf_config`` must provide
                ``dim``, ``rope_head_dim``, ``norm_eps``, ``window_size`` and
                ``cached_state_num``.
            rope: rotary-embedding module exposing ``sin_`` / ``cos_`` tables.
            compress_ratio: number of tokens folded into one compressed slot.
            head_dim: dimension of one compressed KV head.
            rotate: apply a Hadamard rotation to the compressed output.
            prefix: parameter-name prefix used for checkpoint loading.
        """
        super().__init__()
        config = vllm_config.model_config.hf_config
        self.dim = config.dim
        self.head_dim = head_dim
        self.rope_head_dim = config.rope_head_dim
        self.nope_head_dim = head_dim - config.rope_head_dim
        self.compress_ratio = compress_ratio
        # Overlapping compression is only supported/used for a ratio of 4.
        self.overlap = compress_ratio == 4
        self.rotate = rotate
        # Bool arithmetic: coff is 2 when overlapping, 1 otherwise; it
        # multiplies every projection/state width below.
        coff = 1 + self.overlap
        self.norm_eps = config.norm_eps
        self.window_size = config.window_size

        # Additive position embedding consumed by the fused compress kernels,
        # one row per position inside a compression group.
        self.ape = nn.Parameter(
            torch.empty(compress_ratio, coff * self.head_dim,
                        dtype=torch.float32))
        # wkv and wgate are stored in bf16 in the checkpoint, while the
        # parameters here are kept in fp32 for convenience.
        # The first half of the output dimensions is for overlapping
        # compression and the second half for normal compression.
        self.wkv = ReplicatedLinear(
            self.dim,
            coff * self.head_dim,
            bias=False,
            quant_config=None,
            params_dtype=torch.float32,
            prefix=f"{prefix}.wkv",
        )

        self.wgate = ReplicatedLinear(
            self.dim,
            coff * self.head_dim,
            bias=False,
            quant_config=None,
            params_dtype=torch.float32,
            prefix=f"{prefix}.wgate",
        )

        self.norm = RMSNorm(self.head_dim, self.norm_eps)

        self.rotary_emb = rope

        # ``config`` already is hf_config; no need to fetch it again.
        assert hasattr(config, "cached_state_num"), \
            "cached_state_num is not set in hf_config"
        cached_state_num = config.cached_state_num
        # Running per-sequence accumulation state for partially filled
        # compression groups (not persisted in checkpoints).
        self.register_buffer(
            "kv_state",
            torch.zeros(cached_state_num, coff * compress_ratio,
                        coff * self.head_dim, dtype=torch.float32),
            persistent=False,
        )
        # Score state starts at -inf so empty slots never win the gating
        # softmax/max inside the fused kernel.
        self.register_buffer(
            "score_state",
            torch.full(
                (cached_state_num, coff * compress_ratio,
                 coff * self.head_dim),
                float("-inf"),
                dtype=torch.float32,
            ),
            persistent=False,
        )

        # Hadamard matrix used by the fused decode kernel for the optional
        # output rotation (the prefill path calls rotate_activation instead).
        self.hadamard_matrix = torch.tensor(
            hadamard(self.head_dim, dtype=float),
            dtype=torch.get_default_dtype(), device="mlu")

    def overlap_transform(self, tensor: torch.Tensor, value=0):
        """Expand [b, s, r, 2d] into [b, s, 2r, d] for overlapping compression.

        The second half of the last dimension fills rows ``r:2r`` at the same
        step; the first half is shifted one step forward in ``s`` into rows
        ``:r`` (step 0 keeps the fill ``value`` there).
        """
        # tensor: [b,s,r,2d]
        b, s, _, _ = tensor.size()
        ratio, d = self.compress_ratio, self.head_dim
        new_tensor = tensor.new_full((b, s, 2 * ratio, d), value)
        new_tensor[:, :, ratio:] = tensor[:, :, :, d:]
        new_tensor[:, 1:, :ratio] = tensor[:, :-1, :, :d]
        return new_tensor

    def forward_decode(
        self,
        x: torch.Tensor,
        positions: torch.Tensor,
        attn_metadata: AttentionMetadata,
        batch_to_kv_state: torch.Tensor,
        kv_cache: torch.Tensor,
        window_offset: int,
        compressor_slot_mapping: torch.Tensor,
    ):
        """Single-token-per-sequence compression step (decode path).

        The whole pipeline (APE, state update, RMSNorm, rotary embedding,
        optional rotation, cache write) runs inside one fused MLU kernel that
        updates ``kv_state`` / ``score_state`` / ``kv_cache`` in place.
        Returns None: decode consumers read the compressed KV from the cache.
        """
        x = x.float()
        kv_pack, _ = self.wkv(x)
        score_pack, _ = self.wgate(x)

        mlu_ops.fused_compress_single_kv(
            kv=kv_pack.unsqueeze(1),  # (token, D) -> (B, S, D)
            score=score_pack.unsqueeze(1),  # (token, D) -> (B, S, D)
            position=positions,
            ape=self.ape,
            kv_state=self.kv_state,
            score_state=self.score_state,
            gamma=self.norm.weight,
            sin=self.rotary_emb.sin_,
            cos=self.rotary_emb.cos_,
            hadamard_matrix=self.hadamard_matrix,
            slot_mapping=compressor_slot_mapping,
            kv_cache=kv_cache,
            kv_cache_scale=None,
            eps=self.norm_eps,
            overlap=self.overlap,
            rotate=self.rotate,
            state_idx=batch_to_kv_state,
        )

        # Here, return fake compressed_kv.
        return None

    def forward(
        self,
        x: torch.Tensor,
        positions: torch.Tensor,
        attn_metadata: AttentionMetadata,
        batch_to_kv_state: torch.Tensor,
        kv_cache: torch.Tensor,
        window_offset: int,
        compressor_slot_mapping: torch.Tensor,
    ):
        """Dispatch to the prefill or decode implementation.

        Prefill is taken only when the current batch is prefill-only,
        as reported by the shared attention metadata.
        """
        common_metadata = get_common_metadata()
        forward_func: Callable = (
            self.forward_prefill if common_metadata.is_prefill_only
            else self.forward_decode
        )
        return forward_func(
            x,
            positions,
            attn_metadata,
            batch_to_kv_state,
            kv_cache,
            window_offset,
            compressor_slot_mapping,
        )

    def forward_prefill(
        self,
        x: torch.Tensor,
        positions: torch.Tensor,
        attn_metadata: AttentionMetadata,
        batch_to_kv_state: torch.Tensor,
        kv_cache: torch.Tensor,
        window_offset: int,
        compressor_slot_mapping: torch.Tensor,
    ):
        """Compress full prompt sequences (prefill path).

        Returns the compressed KV of shape (compress_token_num, 1, head_dim),
        in the input dtype, after normalization, rotary embedding and the
        optional Hadamard rotation; the same values are also written into the
        paged ``kv_cache``.
        """
        common_metadata = get_common_metadata()
        seq_lens = common_metadata.seq_lens
        query_start_loc = common_metadata.query_start_loc
        query_lens = query_start_loc[1:] - query_start_loc[:-1]

        ratio, overlap = self.compress_ratio, self.overlap
        dtype = x.dtype
        x = x.float()
        kv_pack, _ = self.wkv(x)
        score_pack, _ = self.wgate(x)

        # Each sequence yields floor(len / ratio) compressed tokens; the
        # trailing remainder stays in kv_state for later decode steps.
        compress_lens = query_lens // self.compress_ratio
        cu_compress_lens = torch.cat([
            torch.tensor([0], dtype=compress_lens.dtype,
                         device=compress_lens.device),
            torch.cumsum(compress_lens, dim=0)],
        )

        # Position of each compressed token = position of the first token of
        # its compression group.
        compress_positions = []
        for i in range(len(seq_lens)):
            seqlen = query_lens[i].item()
            remainder = seqlen % ratio
            cutoff = seqlen - remainder
            pos = positions[query_start_loc[i]: query_start_loc[i + 1]]
            positions_ = pos[:cutoff:ratio].contiguous()
            compress_positions.append(positions_)
        kv_positions = torch.cat(compress_positions, dim=0)

        total_compress_len = cu_compress_lens[-1].item()
        kv = torch.empty(
            [total_compress_len, self.head_dim],
            dtype=kv_pack.dtype,
            device=kv_pack.device,
        )

        # Fused kernel: fills ``kv`` (out-param) with the compressed vectors
        # and updates kv_state / score_state for the leftover tokens.
        mlu_ops.fused_compress_multi_kv(
            kv=kv_pack,
            score=score_pack,
            kv_state=self.kv_state,
            score_state=self.score_state,
            state_batch_idx=batch_to_kv_state,
            cu_seqlens=query_start_loc,
            ape=self.ape,
            max_seqlen=common_metadata.max_query_len,
            overlap=overlap,
            compressed_kv=kv,
        )

        if kv.size(0) == 0:
            # Nothing compressed (all sequences shorter than ratio).
            return kv.unsqueeze(-2).to(dtype)  # (compress_token_num, 1, head_size)

        kv = self.norm(kv.to(dtype))

        kv_rope = kv[..., -self.rope_head_dim:].unsqueeze(-2)
        # use compressed cu_seqlens here, so can not call rotary_emb directly
        # NOTE(review): the returned kv_rope is never used afterwards — this
        # appears to rely on rotary_embedding writing through the view into
        # ``kv`` in place; confirm against the op's contract.
        kv_rope = mlu_ops.rotary_embedding(
            kv_rope,
            self.rotary_emb.sin_,
            self.rotary_emb.cos_,
            kv_positions,
            torch.tensor([0, kv_positions.size(0)], dtype=torch.int32,
                         device=kv_positions.device),  # cu_seqlens
            True,  # interleaved
            True,  # discrete
            False,
            common_metadata.max_query_len,
        )

        if self.rotate:
            kv = rotate_activation(kv)

        # Scatter the compressed KV into the paged cache.
        mlu_ops.reshape_paged_cache(
            kv.unsqueeze(1),
            None,
            kv_cache,
            None,
            compressor_slot_mapping,
        )

        return kv.unsqueeze(-2)  # (compress_token_num, 1, head_size)
|