925 lines
36 KiB
Python
925 lines
36 KiB
Python
################################################################################
|
|
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
#
|
|
################################################################################
|
|
from typing import Any, Iterable, Optional, Union
|
|
|
|
import torch
|
|
from fastcore.basics import patch_to
|
|
from torch import nn
|
|
from transformers import DeepseekV2Config, DeepseekV3Config
|
|
|
|
import vllm
|
|
from vllm.attention.backends.abstract import AttentionBackend
|
|
from vllm.attention.ops.common import pack_seq_triton, unpack_seq_triton
|
|
from vllm.config import CacheConfig, VllmConfig
|
|
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
|
|
from vllm.forward_context import get_forward_context
|
|
from vllm.logger import logger
|
|
from vllm.model_executor.layers.layernorm import RMSNorm
|
|
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
|
MergedColumnParallelLinear,
|
|
ReplicatedLinear,
|
|
RowParallelLinear)
|
|
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
|
from vllm.model_executor.layers.quantization import QuantizationConfig
|
|
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
|
per_token_group_quant_fp8)
|
|
from vllm.model_executor.layers.rotary_embedding import get_rope
|
|
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
|
|
from vllm.model_executor.models.deepseek_v2 import (
|
|
DeepseekV2ForCausalLM, DeepseekV2Model, FusedMoE, Indexer, PPMissingLayer,
|
|
default_weight_loader, get_spec_layer_idx_from_weight_name,
|
|
is_pp_missing_parameter, maybe_prefix, maybe_remap_kv_scale_name,
|
|
yarn_get_mscale)
|
|
from vllm.sequence import IntermediateTensors
|
|
from vllm.utils import cdiv
|
|
from vllm.v1.attention.backends.mla.indexer import DeepseekV32IndexerMetadata
|
|
from vllm_br.v1.attention.backends.mla.indexer import (
|
|
SupaDeepseekV32IndexerBackend)
|
|
from .supa_module import (DeepseekV2MoE, MergedGateUpMLPSiluL2, SupaMLAModules,
|
|
SupaMultiHeadLatentAttention)
|
|
|
|
|
|
@patch_to(vllm.model_executor.models.deepseek_v2.DeepseekV32IndexerCache)
def get_attn_backend(self) -> AttentionBackend:
    """Route the DeepSeek-V3.2 indexer cache to the SUPA indexer backend.

    Monkey-patched onto the upstream ``DeepseekV32IndexerCache`` so that
    cache/attention metadata is produced by the Biren implementation.
    """
    return SupaDeepseekV32IndexerBackend
|
|
|
|
|
|
class SupaDeepseekV2MLAAttention(nn.Module):
    """DeepSeek-V2/V3 Multi-head Latent Attention (MLA) for the SUPA backend.

    Builds the low-rank Q/KV projections, rotary embedding and (for
    DeepSeek-V3.2) the sparse top-k indexer, then delegates the actual
    attention computation to ``SupaMultiHeadLatentAttention``.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        config: Union[DeepseekV2Config, DeepseekV3Config],
        hidden_size: int,
        num_heads: int,
        qk_nope_head_dim: int,
        qk_rope_head_dim: int,
        v_head_dim: int,
        q_lora_rank: Optional[int],
        kv_lora_rank: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[dict[str, Any]] = None,
        max_position_embeddings: int = 8192,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        topk_indices_buffer: Optional[torch.Tensor] = None,
    ) -> None:
        super().__init__()
        # DeepSeek-V3.2 configs expose `index_topk`; its presence selects
        # the sparse-indexer ("v32") code paths below.
        self.is_v32 = hasattr(config, "index_topk")

        self.hidden_size = hidden_size
        self.qk_nope_head_dim = qk_nope_head_dim
        self.qk_rope_head_dim = qk_rope_head_dim
        self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
        self.v_head_dim = v_head_dim

        self.q_lora_rank = q_lora_rank
        self.kv_lora_rank = kv_lora_rank

        self.num_heads = num_heads
        tp_size = get_tensor_model_parallel_world_size()
        assert num_heads % tp_size == 0
        self.num_local_heads = num_heads // tp_size

        self.scaling = self.qk_head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        # Exactly one Q-path / KV-path projection combination is created
        # below; the unused ones stay None and are forwarded as None to
        # SupaMLAModules.
        self.fused_qkv_a_proj = None
        self.kv_a_proj_with_mqa = None
        self.q_a_proj = None
        self.q_a_layernorm = None
        self.q_b_proj = None
        self.q_proj = None
        if self.is_v32:
            if self.q_lora_rank is not None:
                # V3.2 + q-LoRA: one fused projection producing both the
                # compressed Q and the compressed KV (+ rope part).
                self.fused_qkv_a_proj = MergedColumnParallelLinear(
                    self.hidden_size, [
                        self.q_lora_rank,
                        self.kv_lora_rank + self.qk_rope_head_dim
                    ],
                    bias=False,
                    quant_config=quant_config,
                    prefix=f"{prefix}.fused_qkv_a_proj",
                    disable_tp=True)
                # NOTE(review): backend-specific flag consumed by the SUPA
                # linear layers; presumably skips a cross-device exchange —
                # confirm against supa_module.
                self.fused_qkv_a_proj.no_need_cross = True
            else:
                self.kv_a_proj_with_mqa = ReplicatedLinear(
                    self.hidden_size,
                    self.kv_lora_rank + self.qk_rope_head_dim,
                    bias=False,
                    quant_config=quant_config,
                    prefix=f"{prefix}.kv_a_proj_with_mqa")
        else:
            # V2/V3 (non-sparse): separate q_a / kv_a down-projections.
            if self.q_lora_rank is not None:
                self.q_a_proj = ReplicatedLinear(self.hidden_size,
                                                 self.q_lora_rank,
                                                 bias=False,
                                                 quant_config=quant_config,
                                                 prefix=f"{prefix}.q_a_proj")

            self.kv_a_proj_with_mqa = ReplicatedLinear(
                self.hidden_size,
                self.kv_lora_rank + self.qk_rope_head_dim,
                bias=False,
                quant_config=quant_config,
                prefix=f"{prefix}.kv_a_proj_with_mqa")

        # Query up-projection: from the LoRA rank when q-LoRA is enabled,
        # otherwise straight from the hidden size.
        if self.q_lora_rank is not None:
            self.q_a_layernorm = RMSNorm(self.q_lora_rank,
                                         eps=config.rms_norm_eps)
            self.q_b_proj = ColumnParallelLinear(self.q_lora_rank,
                                                 self.num_heads *
                                                 self.qk_head_dim,
                                                 bias=False,
                                                 quant_config=quant_config,
                                                 prefix=f"{prefix}.q_b_proj")
        else:
            self.q_proj = ColumnParallelLinear(self.hidden_size,
                                               self.num_heads *
                                               self.qk_head_dim,
                                               bias=False,
                                               quant_config=quant_config,
                                               prefix=f"{prefix}.q_proj")

        self.kv_a_layernorm = RMSNorm(self.kv_lora_rank,
                                      eps=config.rms_norm_eps)
        self.kv_b_proj = ColumnParallelLinear(
            self.kv_lora_rank,
            self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.kv_b_proj")
        self.o_proj = RowParallelLinear(self.num_heads * self.v_head_dim,
                                        self.hidden_size,
                                        bias=False,
                                        quant_config=quant_config,
                                        prefix=f"{prefix}.o_proj")

        if rope_scaling:
            # v32 uses the upstream deepseek_yarn rope; older models use the
            # SUPA-specific variant registered as 'deepseek_yarn_supa'.
            if self.is_v32:
                rope_scaling["rope_type"] = 'deepseek_yarn'
            else:
                rope_scaling["rope_type"] = 'deepseek_yarn_supa'
        self.rotary_emb = get_rope(qk_rope_head_dim,
                                   rotary_dim=qk_rope_head_dim,
                                   max_position=max_position_embeddings,
                                   base=rope_theta,
                                   rope_scaling=rope_scaling,
                                   is_neox_style=False)
        if rope_scaling:
            mscale_all_dim = rope_scaling.get("mscale_all_dim", False)
            scaling_factor = rope_scaling["factor"]
            mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
            # YaRN mscale is applied to both q and k, hence squared here.
            self.scaling = self.scaling * mscale * mscale

        if self.is_v32:
            self.indexer: Optional[SupaIndexer] = SupaIndexer(
                vllm_config, config, hidden_size, q_lora_rank, quant_config,
                cache_config, topk_indices_buffer, f"{prefix}.indexer")
        else:
            self.indexer: Optional[SupaIndexer] = None

        # Bundle all submodules for the backend attention implementation.
        mla_modules = SupaMLAModules(
            kv_a_layernorm=self.kv_a_layernorm,
            kv_b_proj=self.kv_b_proj,
            rotary_emb=self.rotary_emb,
            o_proj=self.o_proj,
            fused_qkv_a_proj=self.fused_qkv_a_proj,
            kv_a_proj_with_mqa=self.kv_a_proj_with_mqa,
            q_a_layernorm=self.q_a_layernorm,
            q_b_proj=self.q_b_proj,
            q_proj=self.q_proj,
            indexer=self.indexer,
            is_sparse=self.is_v32,
            topk_indices_buffer=topk_indices_buffer,
            q_a_proj=self.q_a_proj,
        )

        self.mla_attn = SupaMultiHeadLatentAttention(
            self.hidden_size,
            self.num_local_heads,
            self.scaling,
            self.qk_nope_head_dim,
            self.qk_rope_head_dim,
            self.v_head_dim,
            self.q_lora_rank,
            self.kv_lora_rank,
            mla_modules,
            cache_config,
            quant_config,
            prefix,
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Run MLA via the SUPA backend implementation."""
        return self.mla_attn(positions, hidden_states, is_ds_v32=self.is_v32)
|
|
|
|
|
|
def indexer_k_cache(
    k: torch.Tensor,  # [num_tokens, head_dim] # (8, 128)
    kv_cache: torch.Tensor,  # [1, num_blocks, block_size, cache_stride] # (1, 1024, 2048, 128)
    slot_mapping: torch.Tensor,  # [num_tokens] # (8)
) -> None:
    """Scatter each token's indexer key into its paged-cache slot.

    Every flat slot id from ``slot_mapping`` is decomposed into a
    (block, offset) pair; only the first ``head_dim`` entries of the slot
    are written, since the cache stride may be wider than the key.
    """
    # [TODO] kv_cache shape is not aligned with nv
    num_tokens, head_dim = k.shape[0], k.shape[1]
    block_size = kv_cache.shape[-2]

    for token_idx in range(num_tokens):
        slot = slot_mapping[token_idx]
        block_id = slot // block_size
        offset = slot % block_size
        # [TODO] kv cache stride is longer than head_dim
        kv_cache[0, block_id, offset, :head_dim] = k[token_idx]
|
|
|
|
|
|
def bf16_mqa_logits(
    q: torch.Tensor,
    kv: torch.Tensor,
    weights: torch.Tensor,
    cu_seqlen_ks: torch.Tensor,
    cu_seqlen_ke: torch.Tensor,
):
    """Dense bf16 MQA indexer logits for prefill chunks.

    Args:
        q: [num_q_tokens, n_heads, head_dim] queries.
        kv: [seq_len_kv, head_dim] gathered keys (shared across heads).
        weights: [num_q_tokens, n_heads] per-head mixing weights.
        cu_seqlen_ks / cu_seqlen_ke: per-query start (inclusive) / end
            (exclusive) of the valid key window.

    Returns:
        [num_q_tokens, seq_len_kv] float32 logits; positions outside each
        query's [ks, ke) window are -inf.
    """
    seq_len_kv = kv.shape[0]

    k = kv
    q = q.float()
    k = k.float()

    # Fix: derive the device from the input tensors instead of hard-coding
    # "cuda", so this also runs on CPU and on the SUPA device.
    kv_positions = torch.arange(0, seq_len_kv, device=kv.device)[None, :]
    mask_lo = kv_positions >= cu_seqlen_ks[:, None]
    mask_hi = kv_positions < cu_seqlen_ke[:, None]

    mask = mask_lo & mask_hi
    # score[h, m, n] = <q[m, h, :], k[n, :]>
    score = torch.einsum("mhd,nd->hmn", q, k)
    # ReLU-gated, weight-mixed sum over heads.
    logits = (score.relu() * weights.unsqueeze(-1).transpose(0, 1)).sum(dim=0)
    logits = logits.masked_fill(~mask, float("-inf"))

    return logits
|
|
|
|
|
|
def _ref_fp8_paged_mqa_logits(
|
|
q: torch.Tensor,
|
|
kv_cache: torch.Tensor,
|
|
weights: torch.Tensor,
|
|
context_lens: torch.Tensor,
|
|
block_tables: torch.Tensor,
|
|
max_model_len: int,
|
|
):
|
|
batch_size, next_n, _, _ = q.size()
|
|
_, num_block, block_size, unkonw_size, head_dim = kv_cache.size(
|
|
) # [1, num_block, block_size, _]
|
|
num_block = num_block * 16
|
|
block_size = block_size // 16
|
|
kv_cache = kv_cache.view(num_block, block_size, unkonw_size, head_dim)
|
|
logits = torch.full(
|
|
[batch_size * next_n, max_model_len],
|
|
float("-inf"),
|
|
device=q.device,
|
|
dtype=torch.float32,
|
|
)
|
|
context_lens_list = context_lens.tolist()
|
|
for i in range(batch_size):
|
|
context_len = context_lens_list[i]
|
|
q_offsets = torch.arange(context_len - next_n,
|
|
context_len,
|
|
device="cuda")
|
|
weight_slice = (weights[i * next_n:(i + 1) * next_n, :].transpose(
|
|
0, 1).contiguous())
|
|
for block_rk in range(cdiv(context_len, block_size)):
|
|
block_idx = block_tables[i][block_rk]
|
|
qx, kx = q[i], kv_cache[block_idx]
|
|
k_offsets = torch.arange(
|
|
block_rk * block_size,
|
|
(block_rk + 1) * block_size,
|
|
device="cuda",
|
|
)
|
|
mask = (k_offsets[None, :] < context_len) & (k_offsets[None, :]
|
|
<= q_offsets[:, None])
|
|
s = torch.where(
|
|
mask[None, :, :],
|
|
(qx.transpose(0, 1) @ kx.transpose(0, 1).transpose(1, 2)).to(
|
|
logits.dtype),
|
|
float("-inf"),
|
|
)
|
|
s = torch.relu(s) * weight_slice[..., None]
|
|
s = s.sum(dim=0)
|
|
logits[
|
|
i * next_n:(i + 1) * next_n,
|
|
block_rk * block_size:(block_rk + 1) * block_size,
|
|
] = torch.where(k_offsets[None, :] <= q_offsets[:, None], s,
|
|
float("-inf"))
|
|
return logits
|
|
|
|
|
|
def cp_gather_indexer_k_quant_cache(
    kv_cache,  # [1, num_blocks, block_size, head_dim + 1]
    dst_value,  # [cu_seq_lens[-1], head_dim]
    dst_scale,  # [cu_seq_lens[-1], 4]
    block_table,  # [batch_size, num_blocks]
    cu_seq_lens,  # [batch_size + 1, ]
    batch_size,
):
    """Gather per-request indexer keys from the paged cache into dst_value.

    For every request, all full blocks are copied wholesale and the last
    (possibly partial) block is copied up to the remaining token count.
    The fp8 scale path is disabled (commented out) because this backend
    keeps the indexer cache in bf16; ``dst_scale`` is currently unused.
    """
    _, num_blocks, block_size, _ = kv_cache.shape
    # align to nv: 16 logical blocks are packed per stored block, so
    # re-split before indexing.
    num_blocks = num_blocks * 16
    block_size = block_size // 16
    head_dim = dst_value.shape[-1]
    # flatten each logical block to raw bytes: [num_blocks, block_bytes]
    kv_cache = kv_cache.view(num_blocks, -1)

    expected_value = []
    # expected_scale = []
    for b in range(batch_size):
        # number of tokens belonging to request b
        s = cu_seq_lens[b + 1] - cu_seq_lens[b]
        if s == 0:
            continue
        tot = cdiv(s, block_size)
        blocks = block_table[b, :tot]

        value = []
        # indices of the fully-filled blocks (all but the last)
        full_block = torch.arange(tot - 1,
                                  device=kv_cache.device,
                                  dtype=torch.int32)
        # [TODO] not support index in tensor on br, run in cpu now
        non_remaining_value = kv_cache.cpu()[
            blocks.cpu()[full_block.cpu()], :block_size * head_dim].view(
                -1, head_dim)
        # non_remaining_scale = kv_cache[blocks[full_block],
        # block_size * head_dim:].view(-1, 4)

        # tokens in the trailing, partially-filled block
        remaining = s - (tot - 1) * block_size

        # NOTE(review): blocks[-1] may still live on the device while
        # kv_cache was moved to cpu — scalar-tensor indexing appears to be
        # relied on here; confirm on the br backend.
        value = torch.cat([
            non_remaining_value,
            kv_cache.cpu()[blocks[-1], :remaining * head_dim].view(
                -1, head_dim)
        ],
                          dim=0)
        # scale = torch.cat([
        #     non_remaining_scale,
        #     kv_cache[blocks[-1], block_size * head_dim:block_size * head_dim +
        #     remaining * 4].view(-1, 4)
        # ],
        #                   dim=0)

        expected_value.append(value)
        # expected_scale.append(scale)

    gather_value = torch.cat(expected_value, dim=0).view(-1, head_dim)
    # gather_scale = torch.cat(expected_scale, dim=0).view(-1, 4)
    # reinterpret the raw cache bytes as bf16 and move them to the
    # destination device
    gather_value = gather_value.view(torch.bfloat16).to(dst_value.device)
    # gather_scale = gather_scale.view(torch.float32)
    dst_value.copy_(gather_value)
    # dst_scale.copy_(gather_scale)
|
|
|
|
|
|
def sparse_attn_indexer_fake(
    hidden_states: torch.Tensor,
    k_cache_prefix: str,
    kv_cache: torch.Tensor,
    q_fp8: torch.Tensor,
    k: torch.Tensor,
    weights: torch.Tensor,
    quant_block_size: int,
    scale_fmt: Optional[str],
    topk_tokens: int,
    head_dim: int,
    max_model_len: int,
    total_seq_lens: int,
    topk_indices_buffer: Optional[torch.Tensor],
) -> torch.Tensor:
    # profile run
    # NOTE(Chen): create the max possible flattened_kv. So that
    # profile_run can get correct memory usage.
    """Profile-run stand-in for the sparse indexer.

    Allocates the largest flattened-KV scratch buffer the real path could
    ever need (so memory profiling observes a realistic peak) and returns
    the top-k index buffer untouched.
    """
    support_fp8 = False  # fp8 indexer cache not enabled on this backend yet
    if not support_fp8:
        _flattened_kv = torch.empty([total_seq_lens, head_dim + 4],
                                    device=k.device,
                                    dtype=torch.bfloat16)
    else:
        _flattened_kv = torch.empty([total_seq_lens, head_dim + 4],
                                    device=k.device,
                                    dtype=torch.uint8)
        # views kept only so the allocations mirror the real fp8 layout
        _k_fp8 = _flattened_kv[..., :head_dim].view(
            torch.float8_e4m3fn).contiguous()
        _k_scale = _flattened_kv[..., head_dim:].view(
            torch.float32).contiguous()
    return topk_indices_buffer
|
|
|
|
|
|
def sparse_attn_indexer(
    hidden_states: torch.Tensor,
    k_cache_prefix: str,
    kv_cache: torch.Tensor,
    q_fp8: torch.Tensor,
    k: torch.Tensor,
    weights: torch.Tensor,
    quant_block_size: int,
    scale_fmt: Optional[str],
    topk_tokens: int,
    head_dim: int,
    max_model_len: int,
    total_seq_lens: int,
    topk_indices_buffer: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Select per-token top-k attended positions for the DS-V3.2 indexer.

    Writes this step's keys into the paged indexer cache, scores queries
    against cached keys (separately for prefill chunks and decode tokens),
    and stores the winning indices into ``topk_indices_buffer`` (-1 marks
    invalid entries). Returns the same buffer.
    """

    # careful! this will be None in dummy run
    attn_metadata = get_forward_context().attn_metadata
    # assert isinstance(attn_metadata, dict)
    if not isinstance(attn_metadata, dict):
        # dummy/profile run: fall back to the allocation-only fake path
        return sparse_attn_indexer_fake(
            hidden_states,
            k_cache_prefix,
            kv_cache,
            q_fp8,
            k,
            weights,
            quant_block_size,
            scale_fmt,
            topk_tokens,
            head_dim,
            max_model_len,
            total_seq_lens,
            topk_indices_buffer,
        )

    assert topk_indices_buffer is not None

    attn_metadata = attn_metadata[k_cache_prefix]
    assert isinstance(attn_metadata, DeepseekV32IndexerMetadata)
    slot_mapping = attn_metadata.slot_mapping
    has_decode = attn_metadata.num_decodes > 0
    has_prefill = attn_metadata.num_prefills > 0
    num_decode_tokens = attn_metadata.num_decode_tokens

    # persist this step's keys into the paged indexer cache
    indexer_k_cache(
        k,
        kv_cache,
        slot_mapping,
    )

    # NOTE(review): indexing with hidden_states.shape[1] assumes the SUPA
    # 3D (1, num_tokens, hidden) activation layout — confirm vs. callers.
    topk_indices_buffer[:hidden_states.shape[1]] = -1
    if has_prefill:
        prefill_metadata = attn_metadata.prefill
        for chunk in prefill_metadata.chunks:
            # gather this chunk's cached keys into a dense bf16 buffer
            k_bf16 = torch.empty([chunk.total_seq_lens, head_dim],
                                 device=k.device,
                                 dtype=torch.bfloat16)
            k_scale = None  # bf16 cache stores no fp8 scales
            cp_gather_indexer_k_quant_cache(
                kv_cache,
                k_bf16,
                k_scale,
                chunk.block_table,
                chunk.cu_seq_lens,
                chunk.num_reqs,
            )

            logits = bf16_mqa_logits(
                q_fp8[chunk.token_start:chunk.token_end],
                k_bf16,
                weights[chunk.token_start:chunk.token_end],
                chunk.cu_seqlen_ks,
                chunk.cu_seqlen_ke,
            )

            # [TODO] topk is not aligned with cpu if elements are -inf
            topk_indices = logits.cpu().topk(min(topk_tokens,
                                                 logits.shape[-1]),
                                             dim=-1)[1].supa()
            # re-base indices to the start of each request's key window
            topk_indices -= chunk.cu_seqlen_ks[:, None]
            mask_lo = topk_indices >= 0
            mask_hi = topk_indices - (chunk.cu_seqlen_ke -
                                      chunk.cu_seqlen_ks)[:, None] < 0
            mask = torch.full_like(topk_indices,
                                   False,
                                   dtype=torch.bool,
                                   device=topk_indices.device)
            mask = mask_lo & mask_hi
            # picks outside the valid window are invalidated with -1
            topk_indices = topk_indices.masked_fill(~mask, -1)
            topk_indices_buffer[
                chunk.token_start:chunk.token_end, :topk_indices.
                shape[-1]] = topk_indices.to(dtype=torch.int32)

    if has_decode:
        decode_metadata = attn_metadata.decode
        # kv_cache size requirement [num_block, block_size, n_head, head_dim],
        # we only have [num_block, block_size, head_dim],
        kv_cache = kv_cache.unsqueeze(-2)
        decode_lens = decode_metadata.decode_lens
        if decode_metadata.requires_padding:
            # pad in edge case where we have short chunked prefill length <
            # decode_threshold since we unstrictly split
            # prefill and decode by decode_threshold
            # (currently set to 1 + speculative tokens)
            padded_q_fp8_decode_tokens = pack_seq_triton(
                q_fp8[:num_decode_tokens], decode_lens)
        else:
            padded_q_fp8_decode_tokens = q_fp8[:num_decode_tokens].reshape(
                decode_lens.shape[0], -1, *q_fp8.shape[1:])
        # TODO: move and optimize below logic with triton kernels
        batch_size = padded_q_fp8_decode_tokens.shape[0]
        next_n = padded_q_fp8_decode_tokens.shape[1]
        assert batch_size == decode_metadata.seq_lens.shape[0]
        num_padded_tokens = batch_size * next_n
        logits = _ref_fp8_paged_mqa_logits(
            padded_q_fp8_decode_tokens,
            kv_cache,
            weights[:num_padded_tokens],
            decode_metadata.seq_lens,
            decode_metadata.block_table,
            max_model_len=max_model_len,
        )
        # padded query len
        current_device = padded_q_fp8_decode_tokens.device
        padded_num_tokens = batch_size * next_n
        positions = torch.arange(max_model_len,
                                 device=current_device).unsqueeze(0).expand(
                                     batch_size * next_n, -1)
        row_indices = torch.arange(padded_num_tokens,
                                   device=current_device) // next_n
        next_n_offset = torch.arange(
            padded_num_tokens,
            device=padded_q_fp8_decode_tokens.device) % next_n
        # last valid (causal) context position for each padded slot
        index_end_pos = (decode_metadata.seq_lens[row_indices] - next_n +
                         next_n_offset).unsqueeze(1)
        # index_end_pos: [B * N, 1]
        mask = positions <= index_end_pos
        # mask: [B * N, L]
        logits = logits.masked_fill(~mask, float('-inf'))
        # [TODO] topk is not supported
        device = logits.device
        logits = logits.to('cpu')
        topk_indices = logits.topk(topk_tokens,
                                   dim=-1)[1].to(torch.int32)  # [B * N, K]
        topk_indices = topk_indices.to(device)
        # ensure we don't set indices for the top k
        # that is out of range(masked already)
        # this will happen if context length is shorter than K
        topk_indices[topk_indices > index_end_pos] = -1
        if decode_metadata.requires_padding:
            # if padded, we need to unpack
            # the topk indices removing padded tokens
            topk_indices = unpack_seq_triton(
                topk_indices.reshape(batch_size, -1, topk_indices.shape[-1]),
                decode_lens)
        topk_indices_buffer[:num_decode_tokens, :topk_indices.
                            shape[-1]] = topk_indices.to(dtype=torch.int32)

    return topk_indices_buffer
|
|
|
|
|
|
class SupaIndexer(Indexer):
    """DeepSeek-V3.2 sparse-attention indexer adapted for the SUPA backend.

    Differences from the upstream ``Indexer``: the weights projection is
    rebuilt unquantized, the indexer key cache is kept in bf16 (fp8 path is
    disabled on this backend), and the top-k buffer is zero-initialized.
    """

    def __init__(self,
                 vllm_config: VllmConfig,
                 config: Union[DeepseekV2Config, DeepseekV3Config],
                 hidden_size: int,
                 q_lora_rank: Optional[int],
                 quant_config: Optional[QuantizationConfig],
                 cache_config: Optional[CacheConfig],
                 topk_indices_buffer: Optional[torch.Tensor] = None,
                 prefix: str = "") -> None:
        super().__init__(
            vllm_config=vllm_config,
            config=config,
            hidden_size=hidden_size,
            q_lora_rank=q_lora_rank,
            quant_config=quant_config,
            cache_config=cache_config,
            topk_indices_buffer=topk_indices_buffer,
            prefix=prefix,
        )
        self.n_head = config.index_n_heads  # 64
        # Recreate weights_proj unquantized (quant_config=None), overriding
        # whatever the base class built.
        self.weights_proj = ReplicatedLinear(hidden_size,
                                             self.n_head,
                                             bias=False,
                                             quant_config=None,
                                             prefix=f"{prefix}.weights_proj")
        # Keep indexer keys in bf16; no fp8 quantized cache on this backend.
        self.k_cache.dtype = torch.bfloat16
        self.k_cache.head_dim = config.index_head_dim
        self.topk_indices_buffer.fill_(0)

    def forward(self, hidden_states: torch.Tensor, qr: torch.Tensor, positions,
                rotary_emb) -> torch.Tensor:
        """Project queries/keys, apply rope, and run top-k index selection.

        Returns the top-k indices buffer filled by ``sparse_attn_indexer``.
        """
        q, _ = self.wq_b(qr)
        q = q.view(-1, self.n_head, self.head_dim)
        # split rope / non-rope parts of the query
        q_pe, q_nope = torch.split(
            q, [self.rope_dim, self.head_dim - self.rope_dim], dim=-1)

        k, _ = self.wk(hidden_states)
        k = k.view(-1, self.head_dim)
        k = self.k_norm(k)
        k_pe, k_nope = torch.split(
            k, [self.rope_dim, self.head_dim - self.rope_dim], dim=-1)

        q_pe, k_pe = rotary_emb(positions, q_pe, k_pe.unsqueeze(1))
        q = torch.cat([q_pe, q_nope], dim=-1)
        k = torch.cat([k_pe.squeeze(1), k_nope], dim=-1)

        # we only quant q here since k quant is fused with cache insertion
        q = q.view(-1, self.head_dim)
        support_fp8 = False  # fp8 path disabled on this backend
        if support_fp8:
            q_fp8, q_scale = per_token_group_quant_fp8(
                q,
                self.quant_block_size,
                column_major_scales=False,
                use_ue8m0=self.scale_fmt is not None)
            q_fp8 = q_fp8.view(-1, self.n_head, self.head_dim)
            q_scale = q_scale.view(-1, self.n_head, 1)

            weights, _ = self.weights_proj(hidden_states)
            # fold the per-token quant scale into the head weights
            weights = weights.unsqueeze(
                -1) * q_scale * self.softmax_scale * self.n_head**-0.5
            weights = weights.squeeze(-1)

            return torch.ops.vllm.sparse_attn_indexer(
                hidden_states,
                self.k_cache.prefix,
                self.k_cache.kv_cache[0],
                q_fp8,
                k,
                weights,
                self.quant_block_size,
                self.scale_fmt,
                self.topk_tokens,
                self.head_dim,
                self.max_model_len,
                self.max_total_seq_len,
                self.topk_indices_buffer,
            )
        else:
            q = q.view(-1, self.n_head, self.head_dim)
            weights, _ = self.weights_proj(hidden_states)
            weights = weights.view(-1, self.n_head)
            weights = weights.unsqueeze(
                -1) * self.softmax_scale * self.n_head**-0.5
            weights = weights.squeeze(-1)

            # bf16 path calls the python implementation directly rather
            # than the registered custom op
            return sparse_attn_indexer(
                hidden_states,
                self.k_cache.prefix,
                self.k_cache.kv_cache[0],
                q,
                k,
                weights,
                self.quant_block_size,
                self.scale_fmt,
                self.topk_tokens,
                self.head_dim,
                self.max_model_len,
                self.max_total_seq_len,
                self.topk_indices_buffer,
            )
|
|
|
|
|
|
@patch_to(DeepseekV2Model)
def forward(
    self,
    input_ids: torch.Tensor,
    positions: torch.Tensor,
    intermediate_tensors: Optional[IntermediateTensors],
    inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
    """Run the decoder stack using SUPA's 3D (1, tokens, hidden) layout.

    Activations are unsqueezed to 3D before the layer loop and squeezed
    back to 2D when handed to the next pipeline stage or returned.
    """
    pp_group = get_pp_group()
    if pp_group.is_first_rank:
        residual = None
        hidden_states = (inputs_embeds if inputs_embeds is not None else
                         self.get_input_embeddings(input_ids))
    else:
        assert intermediate_tensors is not None
        hidden_states = intermediate_tensors["hidden_states"]
        residual = intermediate_tensors["residual"]
        residual = residual.unsqueeze(0)  # NOTE: SUPA wants 3D input

    hidden_states = hidden_states.unsqueeze(0)
    for layer in self.layers[self.start_layer:self.end_layer]:
        hidden_states, residual = layer(positions, hidden_states, residual)

    if not pp_group.is_last_rank:
        # hand 2D tensors over to the next pipeline stage
        return IntermediateTensors({
            "hidden_states":
            (hidden_states.squeeze(0)
             if hidden_states is not None else hidden_states),
            "residual":
            (residual.squeeze(0) if residual is not None else residual),
        })

    hidden_states, _ = self.norm(hidden_states, residual)
    return hidden_states.squeeze(0)
|
|
|
|
|
|
@patch_to(DeepseekV2ForCausalLM)
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
    """Patched constructor that enables the SUPA DS-MLA model flags.

    Mirrors the upstream constructor but sets ``use_ds_mla`` (and
    ``use_ds_mla_sparse`` for V3.2 configs) on the model config before the
    model is built.
    """
    # Two-arg super(): functions injected via patch_to have no __class__
    # cell, so zero-arg super() would not resolve here.
    super(DeepseekV2ForCausalLM, self).__init__()
    config = vllm_config.model_config.hf_config
    model_config = vllm_config.model_config
    # flags consumed by the SUPA attention backend selection
    model_config.use_ds_mla = True
    is_v32 = hasattr(config, "index_topk")
    if is_v32:
        model_config.use_ds_mla_sparse = True
    quant_config = vllm_config.quant_config
    self.config = config
    self.quant_config = quant_config

    # `packed_modules_mapping` needs to be modified before
    # initializing DeepseekV2Model, as it is passed inplace to
    # quantization config init and may be used to select the
    # quant_method for relevant layers during initialization.
    self.fuse_qkv_a_proj = hasattr(
        config, "q_lora_rank") and config.q_lora_rank is not None
    if self.fuse_qkv_a_proj:
        self.packed_modules_mapping["fused_qkv_a_proj"] = [
            "q_a_proj",
            "kv_a_proj_with_mqa",
        ]

    self.model = DeepseekV2Model(vllm_config=vllm_config,
                                 prefix=maybe_prefix(prefix, "model"))
    if get_pp_group().is_last_rank:
        self.lm_head = ParallelLMHead(
            config.vocab_size,
            config.hidden_size,
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "lm_head"),
        )
    else:
        # non-final pipeline ranks carry no LM head
        self.lm_head = PPMissingLayer()
    self.logits_processor = LogitsProcessor(config.vocab_size)
    self.make_empty_intermediate_tensors = (
        self.model.make_empty_intermediate_tensors)
|
|
|
|
|
|
@patch_to(DeepseekV2ForCausalLM)
def load_weights(self, weights: Iterable[tuple[str,
                                               torch.Tensor]]) -> set[str]:
    """Load checkpoint weights, routing stacked / expert / plain params.

    Mirrors the upstream DeepseekV2 loader with two SUPA additions: norm
    and e_score_correction_bias parameters are upcast to float32 after
    loading, and the SUPA device cache is flushed after each load.
    """
    stacked_params_mapping = [
        # (param_name, shard_name, shard_id)
        ("gate_up_proj", "gate_proj", 0),
        ("gate_up_proj", "up_proj", 1),
        ("fused_qkv_a_proj", "q_a_proj", 0),
        ("fused_qkv_a_proj", "kv_a_proj_with_mqa", 1),
    ]

    # Params for weights, fp8 weight scales, fp8 activation scales
    # (param_name, weight_name, expert_id, shard_id)
    expert_params_mapping = FusedMoE.make_expert_params_mapping(
        ckpt_gate_proj_name="gate_proj",
        ckpt_down_proj_name="down_proj",
        ckpt_up_proj_name="up_proj",
        num_experts=self.config.n_routed_experts)

    params_dict = dict(self.named_parameters())
    loaded_params: set[str] = set()
    for name, loaded_weight in weights:
        # rotary caches are recomputed at runtime, never loaded
        if "rotary_emb.inv_freq" in name:
            continue

        spec_layer = get_spec_layer_idx_from_weight_name(self.config, name)
        if spec_layer is not None:
            continue  # skip spec decode layers for main model

        for (param_name, weight_name, shard_id) in stacked_params_mapping:
            # Skip non-stacked layers and experts (experts handled below).
            if weight_name not in name:
                continue
            # We have mlp.experts[0].gate_proj in the checkpoint.
            # Since we handle the experts below in expert_params_mapping,
            # we need to skip here BEFORE we update the name, otherwise
            # name will be updated to mlp.experts[0].gate_up_proj, which
            # will then be updated below in expert_params_mapping
            # for mlp.experts[0].gate_gate_up_proj, which breaks load.
            if (("mlp.experts." in name) and name not in params_dict):
                continue
            name_mapped = name.replace(weight_name, param_name)

            # QKV fusion is optional, fall back to normal
            # weight loading if it's not enabled
            # if go with fusion option, then update name
            if ((param_name == "fused_qkv_a_proj")
                    and name_mapped not in params_dict):
                continue
            else:
                name = name_mapped

            # Skip loading extra bias for GPTQ models.
            if name.endswith(".bias") and name not in params_dict:
                continue

            if is_pp_missing_parameter(name, self):
                continue

            if name not in params_dict:
                # logger.debug(f'skip {name}')
                continue

            param = params_dict[name]
            weight_loader = param.weight_loader
            weight_loader(param, loaded_weight, shard_id)
            # weight layout infer
            if name.find("norm.weight") != -1 or name.find(
                    "e_score_correction_bias") != -1:
                param.data = param.data.to(torch.float32)
            torch.supa.empty_cache()
            break
        else:
            # not a stacked param: try the MoE expert mapping next
            for mapping in expert_params_mapping:
                param_name, weight_name, expert_id, shard_id = mapping
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)

                if is_pp_missing_parameter(name, self):
                    continue
                if name not in params_dict:
                    # logger.debug(f'skip {name}')
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param,
                              loaded_weight,
                              name,
                              shard_id=shard_id,
                              expert_id=expert_id)
                # weight layout infer
                if name.find("norm.weight") != -1 or name.find(
                        "e_score_correction_bias") != -1:
                    param.data = param.data.to(torch.float32)
                torch.supa.empty_cache()
                break
            else:
                # plain (non-stacked, non-expert) parameter
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue

                # Remapping the name of FP8 kv-scale.
                name = maybe_remap_kv_scale_name(name, params_dict)
                if name is None:
                    continue

                if is_pp_missing_parameter(name, self):
                    continue
                if name not in params_dict:
                    # logger.debug(f'skip {name}')
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
                # weight layout infer
                if name.find("norm.weight") != -1 or name.find(
                        "e_score_correction_bias") != -1:
                    param.data = param.data.to(torch.float32)
                torch.supa.empty_cache()
        loaded_params.add(name)
    return loaded_params
|
|
|
|
|
|
# Module-level monkey patches: executed at import time, these swap the
# upstream DeepSeek-V2 building blocks for their SUPA (Biren)
# implementations before any model is instantiated.
vllm.model_executor.models.deepseek_v2.DeepseekV2MLP = MergedGateUpMLPSiluL2
logger.debug('[Patch] patch DeepSeekV2 MLP with MergedGateUpMLPSiluL2')
vllm.model_executor.models.deepseek_v2.DeepseekV2MoE = DeepseekV2MoE
logger.debug('[Patch] patch DeepSeekV2 MoE with DeepseekV2MoE')
vllm.model_executor.models.deepseek_v2.DeepseekV2MLAAttention = SupaDeepseekV2MLAAttention
logger.debug('[Patch] patch DeepSeekV2 MLA with SupaDeepseekV2MLAAttention')
vllm.model_executor.models.deepseek_v2.Indexer = SupaIndexer
logger.debug('[Patch] patch DeepSeekV2 Indexer with SupaIndexer')
vllm.model_executor.models.deepseek_v2.MultiHeadLatentAttention = SupaMultiHeadLatentAttention
logger.debug(
    '[Patch] patch DeepSeekV2 MultiHeadLatentAttention with SupaMultiHeadLatentAttention'
)

# vllm.model_executor.models.deepseek_v2.DeepseekV2ForCausalLM.packed_modules_mapping = {
#     "gate_up_proj": ["gate_proj", "up_proj"],
#     # "fused_qkv_a_proj": ["q_a_proj", "kv_a_proj_with_mqa"]
# }
# logger.debug(
#     '[Patch] patch DeepseekV2ForCausalLM with SupportsQuant packed_modules_mapping'
# )
|