Support MLA for DeepSeek-V2 with Triton - step 1 (#905)
This commit is contained in:
@@ -57,6 +57,8 @@ def _fwd_kernel(
|
||||
stride_buf_vh,
|
||||
stride_req_to_tokens_b,
|
||||
BLOCK_DMODEL: tl.constexpr,
|
||||
BLOCK_DPE: tl.constexpr,
|
||||
BLOCK_DV: tl.constexpr,
|
||||
BLOCK_M: tl.constexpr,
|
||||
BLOCK_N: tl.constexpr,
|
||||
logit_cap: tl.constexpr,
|
||||
@@ -75,8 +77,10 @@ def _fwd_kernel(
|
||||
cur_batch_req_idx = tl.load(B_req_idx + cur_seq)
|
||||
|
||||
offs_d = tl.arange(0, BLOCK_DMODEL)
|
||||
offs_dv = tl.arange(0, BLOCK_DV)
|
||||
offs_m = tl.arange(0, BLOCK_M)
|
||||
mask_m = (cur_block_m * BLOCK_M + offs_m) < cur_seq_len_extend
|
||||
|
||||
offs_q = (
|
||||
(cur_seq_extend_start_contiguous + cur_block_m * BLOCK_M + offs_m[:, None])
|
||||
* stride_qbs
|
||||
@@ -85,10 +89,20 @@ def _fwd_kernel(
|
||||
)
|
||||
q = tl.load(Q_Extend + offs_q, mask=mask_m[:, None], other=0.0)
|
||||
|
||||
if BLOCK_DPE > 0:
|
||||
offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE)
|
||||
offs_qpe = (
|
||||
(cur_seq_extend_start_contiguous + cur_block_m * BLOCK_M + offs_m[:, None])
|
||||
* stride_qbs
|
||||
+ cur_head * stride_qh
|
||||
+ offs_dpe[None, :]
|
||||
)
|
||||
qpe = tl.load(Q_Extend + offs_qpe, mask=mask_m[:, None], other=0.0)
|
||||
|
||||
# stage1: compute scores with prefix
|
||||
offs_n = tl.arange(0, BLOCK_N)
|
||||
|
||||
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
|
||||
acc = tl.zeros([BLOCK_M, BLOCK_DV], dtype=tl.float32)
|
||||
deno = tl.zeros([BLOCK_M], dtype=tl.float32)
|
||||
e_max = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
|
||||
|
||||
@@ -110,6 +124,18 @@ def _fwd_kernel(
|
||||
|
||||
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
|
||||
qk += tl.dot(q, k)
|
||||
if BLOCK_DPE > 0:
|
||||
offs_kpe = (
|
||||
offs_kv_loc[None, :] * stride_buf_kbs
|
||||
+ cur_kv_head * stride_buf_kh
|
||||
+ offs_dpe[:, None]
|
||||
)
|
||||
kpe = tl.load(
|
||||
K_Buffer + offs_kpe,
|
||||
mask=mask_n[None, :],
|
||||
other=0.0,
|
||||
)
|
||||
qk += tl.dot(qpe, kpe)
|
||||
qk *= sm_scale
|
||||
|
||||
if logit_cap > 0:
|
||||
@@ -125,7 +151,7 @@ def _fwd_kernel(
|
||||
offs_buf_v = (
|
||||
offs_kv_loc[:, None] * stride_buf_vbs
|
||||
+ cur_kv_head * stride_buf_vh
|
||||
+ offs_d[None, :]
|
||||
+ offs_dv[None, :]
|
||||
)
|
||||
v = tl.load(V_Buffer + offs_buf_v, mask=mask_n[:, None], other=0.0)
|
||||
p = p.to(v.dtype)
|
||||
@@ -150,6 +176,21 @@ def _fwd_kernel(
|
||||
|
||||
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
|
||||
qk += tl.dot(q, k)
|
||||
|
||||
if BLOCK_DPE > 0:
|
||||
offs_kpe = (
|
||||
(cur_seq_extend_start_contiguous + start_n + offs_n[None, :])
|
||||
* stride_kbs
|
||||
+ cur_kv_head * stride_kh
|
||||
+ offs_dpe[:, None]
|
||||
)
|
||||
kpe = tl.load(
|
||||
K_Extend + offs_kpe,
|
||||
mask=mask_n[None, :],
|
||||
other=0.0,
|
||||
)
|
||||
qk += tl.dot(qpe, kpe)
|
||||
|
||||
qk *= sm_scale
|
||||
|
||||
if logit_cap > 0:
|
||||
@@ -169,7 +210,7 @@ def _fwd_kernel(
|
||||
offs_v = (
|
||||
(cur_seq_extend_start_contiguous + start_n + offs_n[:, None]) * stride_vbs
|
||||
+ cur_kv_head * stride_vh
|
||||
+ offs_d[None, :]
|
||||
+ offs_dv[None, :]
|
||||
)
|
||||
v = tl.load(V_Extend + offs_v, mask=mask_n[:, None], other=0.0)
|
||||
p = p.to(v.dtype)
|
||||
@@ -181,7 +222,7 @@ def _fwd_kernel(
|
||||
(cur_seq_extend_start_contiguous + cur_block_m * BLOCK_M + offs_m[:, None])
|
||||
* stride_obs
|
||||
+ cur_head * stride_oh
|
||||
+ offs_d[None, :]
|
||||
+ offs_dv[None, :]
|
||||
)
|
||||
tl.store(O_Extend + offs_o, acc / deno[:, None], mask=mask_m[:, None])
|
||||
|
||||
@@ -217,8 +258,17 @@ def extend_attention_fwd(
|
||||
o_extend.shape[-1],
|
||||
)
|
||||
|
||||
assert Lq == Lk and Lk == Lv and Lv == Lo
|
||||
assert Lq in {16, 32, 64, 128, 256}
|
||||
assert Lq == Lk and Lv == Lo
|
||||
assert Lq in {16, 32, 64, 128, 256, 576}
|
||||
assert Lv in {16, 32, 64, 128, 256, 512}
|
||||
|
||||
if Lq == 576:
|
||||
BLOCK_DMODEL = 512
|
||||
BLOCK_DPE = 64
|
||||
else:
|
||||
BLOCK_DMODEL = Lq
|
||||
BLOCK_DPE = 0
|
||||
BLOCK_DV = Lv
|
||||
|
||||
if CUDA_CAPABILITY[0] >= 8:
|
||||
BLOCK_M, BLOCK_N = (128, 128) if Lq <= 128 else (64, 64)
|
||||
@@ -260,7 +310,9 @@ def extend_attention_fwd(
|
||||
v_buffer.stride(0),
|
||||
v_buffer.stride(1),
|
||||
req_to_tokens.stride(0),
|
||||
BLOCK_DMODEL=Lq,
|
||||
BLOCK_DMODEL=BLOCK_DMODEL,
|
||||
BLOCK_DPE=BLOCK_DPE,
|
||||
BLOCK_DV=BLOCK_DV,
|
||||
BLOCK_M=BLOCK_M,
|
||||
BLOCK_N=BLOCK_N,
|
||||
num_warps=num_warps,
|
||||
|
||||
@@ -38,16 +38,22 @@ class RadixAttention(nn.Module):
|
||||
num_kv_heads: int,
|
||||
layer_id: int,
|
||||
logit_cap: int = -1,
|
||||
v_head_dim: int = -1,
|
||||
):
|
||||
super().__init__()
|
||||
self.tp_q_head_num = num_heads
|
||||
self.tp_k_head_num = num_kv_heads
|
||||
self.tp_v_head_num = num_kv_heads
|
||||
self.head_dim = head_dim
|
||||
self.qk_head_dim = head_dim
|
||||
self.v_head_dim = v_head_dim if v_head_dim != -1 else head_dim
|
||||
self.scaling = scaling
|
||||
self.layer_id = layer_id
|
||||
|
||||
if not global_server_args_dict.get("disable_flashinfer", False):
|
||||
if (
|
||||
not global_server_args_dict.get("disable_flashinfer", False)
|
||||
and self.qk_head_dim == self.v_head_dim
|
||||
):
|
||||
self.extend_forward = self.extend_forward_flashinfer
|
||||
self.decode_forward = self.decode_forward_flashinfer
|
||||
else:
|
||||
@@ -57,13 +63,17 @@ class RadixAttention(nn.Module):
|
||||
self.logit_cap = logit_cap if logit_cap is not None and logit_cap > 0 else 0
|
||||
|
||||
def extend_forward_triton(self, q, k, v, input_metadata: InputMetadata):
|
||||
o = torch.empty_like(q)
|
||||
if self.qk_head_dim != self.v_head_dim:
|
||||
o = q.new_empty((q.shape[0], self.tp_q_head_num * self.v_head_dim))
|
||||
else:
|
||||
o = torch.empty_like(q)
|
||||
|
||||
self.store_kv_cache(k, v, input_metadata)
|
||||
extend_attention_fwd(
|
||||
q.view(-1, self.tp_q_head_num, self.head_dim),
|
||||
q.view(-1, self.tp_q_head_num, self.qk_head_dim),
|
||||
k.contiguous(),
|
||||
v.contiguous(),
|
||||
o.view(-1, self.tp_q_head_num, self.head_dim),
|
||||
o.view(-1, self.tp_q_head_num, self.v_head_dim),
|
||||
input_metadata.token_to_kv_pool.get_key_buffer(self.layer_id),
|
||||
input_metadata.token_to_kv_pool.get_value_buffer(self.layer_id),
|
||||
input_metadata.req_to_token_pool.req_to_token,
|
||||
@@ -82,14 +92,17 @@ class RadixAttention(nn.Module):
|
||||
return o
|
||||
|
||||
def decode_forward_triton(self, q, k, v, input_metadata: InputMetadata):
|
||||
o = torch.empty_like(q)
|
||||
if self.qk_head_dim != self.v_head_dim:
|
||||
o = q.new_empty((q.shape[0], self.tp_q_head_num * self.v_head_dim))
|
||||
else:
|
||||
o = torch.empty_like(q)
|
||||
self.store_kv_cache(k, v, input_metadata)
|
||||
|
||||
token_attention_fwd(
|
||||
q.view(-1, self.tp_q_head_num, self.head_dim),
|
||||
q.view(-1, self.tp_q_head_num, self.qk_head_dim),
|
||||
input_metadata.token_to_kv_pool.get_key_buffer(self.layer_id),
|
||||
input_metadata.token_to_kv_pool.get_value_buffer(self.layer_id),
|
||||
o.view(-1, self.tp_q_head_num, self.head_dim),
|
||||
o.view(-1, self.tp_q_head_num, self.v_head_dim),
|
||||
input_metadata.req_to_token_pool.req_to_token,
|
||||
input_metadata.req_pool_indices,
|
||||
input_metadata.triton_start_loc,
|
||||
@@ -160,8 +173,8 @@ class RadixAttention(nn.Module):
|
||||
return o.view(-1, self.tp_q_head_num * self.head_dim)
|
||||
|
||||
def forward(self, q, k, v, input_metadata: InputMetadata):
|
||||
k = k.view(-1, self.tp_k_head_num, self.head_dim)
|
||||
v = v.view(-1, self.tp_v_head_num, self.head_dim)
|
||||
k = k.view(-1, self.tp_k_head_num, self.qk_head_dim)
|
||||
v = v.view(-1, self.tp_v_head_num, self.v_head_dim)
|
||||
|
||||
if input_metadata.forward_mode == ForwardMode.EXTEND:
|
||||
return self.extend_forward(q, k, v, input_metadata)
|
||||
|
||||
@@ -54,6 +54,7 @@ def _fwd_kernel_stage1(
|
||||
att_stride_h,
|
||||
kv_group_num: tl.constexpr,
|
||||
BLOCK_DMODEL: tl.constexpr,
|
||||
BLOCK_DPE: tl.constexpr,
|
||||
BLOCK_N: tl.constexpr,
|
||||
logit_cap: tl.constexpr,
|
||||
):
|
||||
@@ -73,6 +74,10 @@ def _fwd_kernel_stage1(
|
||||
|
||||
off_q = cur_batch * stride_qbs + cur_head * stride_qh + offs_d
|
||||
|
||||
if BLOCK_DPE > 0:
|
||||
offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE)
|
||||
off_qpe = cur_batch * stride_qbs + cur_head * stride_qh + offs_dpe
|
||||
|
||||
offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N)
|
||||
|
||||
block_stard_index = start_n * BLOCK_N
|
||||
@@ -97,6 +102,19 @@ def _fwd_kernel_stage1(
|
||||
other=0.0,
|
||||
).to(REDUCE_TRITON_TYPE)
|
||||
att_value = tl.sum(q[None, :] * k, 1)
|
||||
if BLOCK_DPE > 0:
|
||||
qpe = tl.load(Q + off_qpe + start_mark).to(REDUCE_TRITON_TYPE)
|
||||
offs_buf_kpe = (
|
||||
k_loc[:, None] * stride_buf_kbs
|
||||
+ cur_kv_head * stride_buf_kh
|
||||
+ offs_dpe[None, :]
|
||||
)
|
||||
kpe = tl.load(
|
||||
K_Buffer + offs_buf_kpe,
|
||||
mask=offs_n_new[:, None] < cur_batch_end_index,
|
||||
other=0.0,
|
||||
).to(REDUCE_TRITON_TYPE)
|
||||
att_value += tl.sum(qpe[None, :] * kpe, 1)
|
||||
att_value *= sm_scale
|
||||
|
||||
if logit_cap > 0:
|
||||
@@ -192,7 +210,14 @@ def _token_att_m_fwd(
|
||||
# shape constraints
|
||||
Lq, Lk = q.shape[-1], k_buffer.shape[-1]
|
||||
assert Lq == Lk
|
||||
assert Lk in {16, 32, 64, 128, 256}
|
||||
assert Lk in {16, 32, 64, 128, 256, 576}
|
||||
|
||||
if Lk == 576:
|
||||
BLOCK_DMODEL = 512
|
||||
BLOCK_DPE = 64
|
||||
else:
|
||||
BLOCK_DMODEL = Lk
|
||||
BLOCK_DPE = 0
|
||||
|
||||
batch, head_num = B_req_idx.shape[0], q.shape[1]
|
||||
|
||||
@@ -220,7 +245,8 @@ def _token_att_m_fwd(
|
||||
k_buffer.stride(1),
|
||||
att_out.stride(0),
|
||||
kv_group_num=kv_group_num,
|
||||
BLOCK_DMODEL=Lk,
|
||||
BLOCK_DMODEL=BLOCK_DMODEL,
|
||||
BLOCK_DPE=BLOCK_DPE,
|
||||
BLOCK_N=BLOCK,
|
||||
logit_cap=logit_cap,
|
||||
num_warps=num_warps,
|
||||
|
||||
@@ -29,7 +29,7 @@ from sglang.global_config import global_config
|
||||
from sglang.srt.constrained import RegexGuide
|
||||
from sglang.srt.constrained.jump_forward import JumpForwardMap
|
||||
from sglang.srt.mem_cache.chunk_cache import ChunkCache
|
||||
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPool
|
||||
from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
|
||||
from sglang.srt.mem_cache.radix_cache import RadixCache
|
||||
|
||||
INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
|
||||
@@ -39,6 +39,7 @@ global_server_args_dict = {
|
||||
"disable_flashinfer": False,
|
||||
"disable_flashinfer_sampling": False,
|
||||
"attention_reduce_in_fp32": False,
|
||||
"enable_mla": False,
|
||||
}
|
||||
|
||||
|
||||
@@ -289,7 +290,7 @@ class Batch:
|
||||
# Request, memory pool, and cache
|
||||
reqs: List[Req]
|
||||
req_to_token_pool: ReqToTokenPool
|
||||
token_to_kv_pool: TokenToKVPool
|
||||
token_to_kv_pool: BaseTokenToKVPool
|
||||
tree_cache: RadixCache
|
||||
|
||||
# Batched arguments to model runner
|
||||
@@ -780,7 +781,7 @@ class InputMetadata:
|
||||
seq_lens: torch.Tensor
|
||||
positions: torch.Tensor
|
||||
req_to_token_pool: ReqToTokenPool
|
||||
token_to_kv_pool: TokenToKVPool
|
||||
token_to_kv_pool: BaseTokenToKVPool
|
||||
|
||||
# For extend
|
||||
extend_seq_lens: torch.Tensor
|
||||
|
||||
@@ -57,32 +57,18 @@ class ReqToTokenPool:
|
||||
self.can_use_mem_size = len(self.mem_state)
|
||||
|
||||
|
||||
class TokenToKVPool:
|
||||
class BaseTokenToKVPool:
|
||||
"""A memory pool that maps a token to its kv cache locations"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
size: int,
|
||||
dtype: torch.dtype,
|
||||
head_num: int,
|
||||
head_dim: int,
|
||||
layer_num: int,
|
||||
):
|
||||
self.size = size
|
||||
|
||||
# We also add one slot. This slot is used for writing dummy output from padded tokens.
|
||||
self.mem_state = torch.ones((self.size + 1,), dtype=torch.bool, device="cuda")
|
||||
|
||||
# [size, head_num, head_dim] for each layer
|
||||
self.k_buffer = [
|
||||
torch.empty((size + 1, head_num, head_dim), dtype=dtype, device="cuda")
|
||||
for _ in range(layer_num)
|
||||
]
|
||||
self.v_buffer = [
|
||||
torch.empty((size + 1, head_num, head_dim), dtype=dtype, device="cuda")
|
||||
for _ in range(layer_num)
|
||||
]
|
||||
|
||||
# Prefetch buffer
|
||||
self.prefetch_buffer = torch.empty(0, device="cuda", dtype=torch.int32)
|
||||
self.prefetch_chunk_size = 512
|
||||
@@ -90,15 +76,6 @@ class TokenToKVPool:
|
||||
self.can_use_mem_size = self.size
|
||||
self.clear()
|
||||
|
||||
def get_key_buffer(self, layer_id: int):
|
||||
return self.k_buffer[layer_id]
|
||||
|
||||
def get_value_buffer(self, layer_id: int):
|
||||
return self.v_buffer[layer_id]
|
||||
|
||||
def get_kv_buffer(self, layer_id: int):
|
||||
return self.k_buffer[layer_id], self.v_buffer[layer_id]
|
||||
|
||||
def available_size(self):
|
||||
return self.can_use_mem_size + len(self.prefetch_buffer)
|
||||
|
||||
@@ -139,3 +116,67 @@ class TokenToKVPool:
|
||||
|
||||
# We also add one slot. This slot is used for writing dummy output from padded tokens.
|
||||
self.mem_state[0] = False
|
||||
|
||||
|
||||
class MHATokenToKVPool(BaseTokenToKVPool):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
size: int,
|
||||
dtype: torch.dtype,
|
||||
head_num: int,
|
||||
head_dim: int,
|
||||
layer_num: int,
|
||||
):
|
||||
super().__init__(size)
|
||||
|
||||
# [size, head_num, head_dim] for each layer
|
||||
self.k_buffer = [
|
||||
torch.empty((size + 1, head_num, head_dim), dtype=dtype, device="cuda")
|
||||
for _ in range(layer_num)
|
||||
]
|
||||
self.v_buffer = [
|
||||
torch.empty((size + 1, head_num, head_dim), dtype=dtype, device="cuda")
|
||||
for _ in range(layer_num)
|
||||
]
|
||||
|
||||
def get_key_buffer(self, layer_id: int):
|
||||
return self.k_buffer[layer_id]
|
||||
|
||||
def get_value_buffer(self, layer_id: int):
|
||||
return self.v_buffer[layer_id]
|
||||
|
||||
def get_kv_buffer(self, layer_id: int):
|
||||
return self.k_buffer[layer_id], self.v_buffer[layer_id]
|
||||
|
||||
|
||||
class MLATokenToKVPool(BaseTokenToKVPool):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
size: int,
|
||||
dtype: torch.dtype,
|
||||
kv_lora_rank: int,
|
||||
qk_rope_head_dim: int,
|
||||
layer_num: int,
|
||||
):
|
||||
super().__init__(size)
|
||||
|
||||
self.kv_lora_rank = kv_lora_rank
|
||||
self.kv_buffer = [
|
||||
torch.empty(
|
||||
(size + 1, 1, kv_lora_rank + qk_rope_head_dim),
|
||||
dtype=dtype,
|
||||
device="cuda",
|
||||
)
|
||||
for _ in range(layer_num)
|
||||
]
|
||||
|
||||
def get_key_buffer(self, layer_id: int):
|
||||
return self.kv_buffer[layer_id]
|
||||
|
||||
def get_value_buffer(self, layer_id: int):
|
||||
return self.kv_buffer[layer_id][..., : self.kv_lora_rank]
|
||||
|
||||
def get_kv_buffer(self, layer_id: int):
|
||||
return self.get_key_buffer(layer_id), self.get_value_buffer(layer_id)
|
||||
|
||||
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
|
||||
from enum import IntEnum, auto
|
||||
from typing import Optional
|
||||
|
||||
from transformers import PretrainedConfig
|
||||
@@ -20,6 +21,11 @@ from transformers import PretrainedConfig
|
||||
from sglang.srt.hf_transformers_utils import get_config, get_context_length
|
||||
|
||||
|
||||
class AttentionArch(IntEnum):
|
||||
MLA = auto()
|
||||
MHA = auto()
|
||||
|
||||
|
||||
class ModelConfig:
|
||||
def __init__(
|
||||
self,
|
||||
@@ -55,6 +61,11 @@ class ModelConfig:
|
||||
# FIXME: temporary special judge for deepseek v2 MLA architecture
|
||||
if "DeepseekV2ForCausalLM" in self.hf_config.architectures:
|
||||
self.head_dim = 256
|
||||
self.attention_arch = AttentionArch.MLA
|
||||
self.kv_lora_rank = self.hf_config.kv_lora_rank
|
||||
self.qk_rope_head_dim = self.hf_config.qk_rope_head_dim
|
||||
else:
|
||||
self.attention_arch = AttentionArch.MHA
|
||||
|
||||
self.num_attention_heads = self.hf_config.num_attention_heads
|
||||
self.num_key_value_heads = getattr(self.hf_config, "num_key_value_heads", None)
|
||||
|
||||
@@ -47,7 +47,12 @@ from sglang.srt.managers.schedule_batch import (
|
||||
InputMetadata,
|
||||
global_server_args_dict,
|
||||
)
|
||||
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPool
|
||||
from sglang.srt.mem_cache.memory_pool import (
|
||||
MHATokenToKVPool,
|
||||
MLATokenToKVPool,
|
||||
ReqToTokenPool,
|
||||
)
|
||||
from sglang.srt.model_config import AttentionArch
|
||||
from sglang.srt.server_args import ServerArgs
|
||||
from sglang.srt.utils import (
|
||||
get_available_gpu_memory,
|
||||
@@ -86,6 +91,7 @@ class ModelRunner:
|
||||
"disable_flashinfer": server_args.disable_flashinfer,
|
||||
"disable_flashinfer_sampling": server_args.disable_flashinfer_sampling,
|
||||
"attention_reduce_in_fp32": server_args.attention_reduce_in_fp32,
|
||||
"enable_mla": server_args.enable_mla,
|
||||
}
|
||||
)
|
||||
|
||||
@@ -193,15 +199,23 @@ class ModelRunner:
|
||||
available_gpu_memory = get_available_gpu_memory(
|
||||
self.gpu_id, distributed=self.tp_size > 1
|
||||
)
|
||||
head_dim = self.model_config.head_dim
|
||||
head_num = self.model_config.get_num_kv_heads(self.tp_size)
|
||||
cell_size = (
|
||||
head_num
|
||||
* head_dim
|
||||
* self.model_config.num_hidden_layers
|
||||
* 2
|
||||
* torch._utils._element_size(self.dtype)
|
||||
)
|
||||
if (
|
||||
self.model_config.attention_arch == AttentionArch.MLA
|
||||
and self.server_args.enable_mla
|
||||
):
|
||||
cell_size = (
|
||||
(self.model_config.kv_lora_rank + self.model_config.qk_rope_head_dim)
|
||||
* self.model_config.num_hidden_layers
|
||||
* torch._utils._element_size(self.dtype)
|
||||
)
|
||||
else:
|
||||
cell_size = (
|
||||
self.model_config.get_num_kv_heads(self.tp_size)
|
||||
* self.model_config.head_dim
|
||||
* self.model_config.num_hidden_layers
|
||||
* 2
|
||||
* torch._utils._element_size(self.dtype)
|
||||
)
|
||||
rest_memory = available_gpu_memory - total_gpu_memory * (
|
||||
1 - self.mem_fraction_static
|
||||
)
|
||||
@@ -241,13 +255,28 @@ class ModelRunner:
|
||||
max_num_reqs,
|
||||
self.model_config.context_len + 8,
|
||||
)
|
||||
self.token_to_kv_pool = TokenToKVPool(
|
||||
self.max_total_num_tokens,
|
||||
dtype=self.dtype,
|
||||
head_num=self.model_config.get_num_kv_heads(self.tp_size),
|
||||
head_dim=self.model_config.head_dim,
|
||||
layer_num=self.model_config.num_hidden_layers,
|
||||
)
|
||||
if (
|
||||
self.model_config.attention_arch == AttentionArch.MLA
|
||||
and self.server_args.enable_mla
|
||||
):
|
||||
self.token_to_kv_pool = MLATokenToKVPool(
|
||||
self.max_total_num_tokens,
|
||||
dtype=self.dtype,
|
||||
kv_lora_rank=self.model_config.kv_lora_rank,
|
||||
qk_rope_head_dim=self.model_config.qk_rope_head_dim,
|
||||
layer_num=self.model_config.num_hidden_layers,
|
||||
)
|
||||
logger.info("using MLA Triton implementaion, flashinfer is disabled")
|
||||
# FIXME: temporarily only Triton MLA is supported
|
||||
self.server_args.disable_flashinfer = True
|
||||
else:
|
||||
self.token_to_kv_pool = MHATokenToKVPool(
|
||||
self.max_total_num_tokens,
|
||||
dtype=self.dtype,
|
||||
head_num=self.model_config.get_num_kv_heads(self.tp_size),
|
||||
head_dim=self.model_config.head_dim,
|
||||
layer_num=self.model_config.num_hidden_layers,
|
||||
)
|
||||
logger.info(
|
||||
f"[gpu={self.gpu_id}] Memory pool end. "
|
||||
f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
|
||||
|
||||
@@ -45,6 +45,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
|
||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||
from sglang.srt.layers.radix_attention import RadixAttention
|
||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||
from sglang.srt.model_executor.model_runner import InputMetadata
|
||||
|
||||
|
||||
@@ -312,6 +313,165 @@ class DeepseekV2Attention(nn.Module):
|
||||
return output
|
||||
|
||||
|
||||
class DeepseekV2AttentionMLA(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
hidden_size: int,
|
||||
num_heads: int,
|
||||
qk_nope_head_dim: int,
|
||||
qk_rope_head_dim: int,
|
||||
v_head_dim: int,
|
||||
q_lora_rank: int,
|
||||
kv_lora_rank: int,
|
||||
rope_theta: float = 10000,
|
||||
rope_scaling: Optional[Dict[str, Any]] = None,
|
||||
max_position_embeddings: int = 8192,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
layer_id=None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.layer_id = layer_id
|
||||
self.hidden_size = hidden_size
|
||||
self.qk_nope_head_dim = qk_nope_head_dim
|
||||
self.qk_rope_head_dim = qk_rope_head_dim
|
||||
self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
|
||||
self.v_head_dim = v_head_dim
|
||||
self.q_lora_rank = q_lora_rank
|
||||
self.kv_lora_rank = kv_lora_rank
|
||||
self.num_heads = num_heads
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
assert num_heads % tp_size == 0
|
||||
self.num_local_heads = num_heads // tp_size
|
||||
self.scaling = self.qk_head_dim**-0.5
|
||||
self.rope_theta = rope_theta
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
|
||||
if self.q_lora_rank is not None:
|
||||
self.q_a_proj = ReplicatedLinear(
|
||||
self.hidden_size,
|
||||
self.q_lora_rank,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps)
|
||||
self.q_b_proj = ColumnParallelLinear(
|
||||
q_lora_rank,
|
||||
self.num_heads * self.qk_head_dim,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
else:
|
||||
self.q_proj = ColumnParallelLinear(
|
||||
self.hidden_size,
|
||||
self.num_heads * self.qk_head_dim,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
|
||||
self.kv_a_proj_with_mqa = ReplicatedLinear(
|
||||
self.hidden_size,
|
||||
self.kv_lora_rank + self.qk_rope_head_dim,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps)
|
||||
self.kv_b_proj = ColumnParallelLinear(
|
||||
self.kv_lora_rank,
|
||||
self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
# O projection.
|
||||
self.o_proj = RowParallelLinear(
|
||||
self.num_heads * self.v_head_dim,
|
||||
self.hidden_size,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
rope_scaling["type"] = "deepseek_yarn"
|
||||
self.rotary_emb = get_rope(
|
||||
qk_rope_head_dim,
|
||||
rotary_dim=qk_rope_head_dim,
|
||||
max_position=max_position_embeddings,
|
||||
base=rope_theta,
|
||||
rope_scaling=rope_scaling,
|
||||
is_neox_style=False,
|
||||
)
|
||||
|
||||
if rope_scaling:
|
||||
mscale_all_dim = rope_scaling.get("mscale_all_dim", False)
|
||||
scaling_factor = rope_scaling["factor"]
|
||||
mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
|
||||
self.scaling = self.scaling * mscale * mscale
|
||||
|
||||
self.attn = RadixAttention(
|
||||
self.num_local_heads,
|
||||
self.kv_lora_rank + self.qk_rope_head_dim,
|
||||
self.scaling,
|
||||
num_kv_heads=1,
|
||||
layer_id=layer_id,
|
||||
v_head_dim=self.kv_lora_rank,
|
||||
)
|
||||
|
||||
kv_b_proj = self.kv_b_proj
|
||||
w_kc, w_vc = kv_b_proj.weight.unflatten(
|
||||
0, (-1, qk_nope_head_dim + v_head_dim)
|
||||
).split([qk_nope_head_dim, v_head_dim], dim=1)
|
||||
self.w_kc = w_kc
|
||||
self.w_vc = w_vc
|
||||
|
||||
def forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
input_metadata: InputMetadata,
|
||||
) -> torch.Tensor:
|
||||
q_len = hidden_states.shape[0]
|
||||
q_input = hidden_states.new_empty(
|
||||
q_len, self.num_local_heads, self.kv_lora_rank + self.qk_rope_head_dim
|
||||
)
|
||||
if self.q_lora_rank is not None:
|
||||
q = self.q_a_proj(hidden_states)[0]
|
||||
q = self.q_a_layernorm(q)
|
||||
q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
|
||||
else:
|
||||
q = self.q_proj(hidden_states)[0].view(
|
||||
-1, self.num_local_heads, self.qk_head_dim
|
||||
)
|
||||
q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
|
||||
q_nope_out = q_input[..., : self.kv_lora_rank]
|
||||
torch.bmm(q_nope.transpose(0, 1), self.w_kc, out=q_nope_out.transpose(0, 1))
|
||||
|
||||
k_input = self.kv_a_proj_with_mqa(hidden_states)[0].unsqueeze(1)
|
||||
k_pe = k_input[..., self.kv_lora_rank :]
|
||||
v_input = k_input[..., : self.kv_lora_rank]
|
||||
v_input = self.kv_a_layernorm(v_input.contiguous())
|
||||
k_input[..., : self.kv_lora_rank] = v_input
|
||||
|
||||
q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
|
||||
q_input[..., self.kv_lora_rank :] = q_pe
|
||||
k_input[..., self.kv_lora_rank :] = k_pe
|
||||
|
||||
attn_output = self.attn(q_input, k_input, v_input, input_metadata)
|
||||
attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)
|
||||
attn_bmm_output = attn_output.new_empty(
|
||||
q_len, self.num_local_heads, self.v_head_dim
|
||||
)
|
||||
torch.bmm(
|
||||
attn_output.transpose(0, 1),
|
||||
self.w_vc.transpose(1, 2).contiguous(),
|
||||
out=attn_bmm_output.transpose(0, 1),
|
||||
)
|
||||
|
||||
attn_output = attn_bmm_output.flatten(1, 2)
|
||||
output, _ = self.o_proj(attn_output)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
class DeepseekV2DecoderLayer(nn.Module):
|
||||
|
||||
def __init__(
|
||||
@@ -326,22 +486,44 @@ class DeepseekV2DecoderLayer(nn.Module):
|
||||
rope_theta = getattr(config, "rope_theta", 10000)
|
||||
rope_scaling = getattr(config, "rope_scaling", None)
|
||||
max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
|
||||
self.self_attn = DeepseekV2Attention(
|
||||
config=config,
|
||||
hidden_size=self.hidden_size,
|
||||
num_heads=config.num_attention_heads,
|
||||
qk_nope_head_dim=config.qk_nope_head_dim,
|
||||
qk_rope_head_dim=config.qk_rope_head_dim,
|
||||
v_head_dim=config.v_head_dim,
|
||||
q_lora_rank=config.q_lora_rank if hasattr(config, "q_lora_rank") else None,
|
||||
kv_lora_rank=config.kv_lora_rank,
|
||||
rope_theta=rope_theta,
|
||||
rope_scaling=rope_scaling,
|
||||
max_position_embeddings=max_position_embeddings,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
layer_id=layer_id,
|
||||
)
|
||||
if global_server_args_dict["enable_mla"]:
|
||||
self.self_attn = DeepseekV2AttentionMLA(
|
||||
config=config,
|
||||
hidden_size=self.hidden_size,
|
||||
num_heads=config.num_attention_heads,
|
||||
qk_nope_head_dim=config.qk_nope_head_dim,
|
||||
qk_rope_head_dim=config.qk_rope_head_dim,
|
||||
v_head_dim=config.v_head_dim,
|
||||
q_lora_rank=(
|
||||
config.q_lora_rank if hasattr(config, "q_lora_rank") else None
|
||||
),
|
||||
kv_lora_rank=config.kv_lora_rank,
|
||||
rope_theta=rope_theta,
|
||||
rope_scaling=rope_scaling,
|
||||
max_position_embeddings=max_position_embeddings,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
layer_id=layer_id,
|
||||
)
|
||||
else:
|
||||
self.self_attn = DeepseekV2Attention(
|
||||
config=config,
|
||||
hidden_size=self.hidden_size,
|
||||
num_heads=config.num_attention_heads,
|
||||
qk_nope_head_dim=config.qk_nope_head_dim,
|
||||
qk_rope_head_dim=config.qk_rope_head_dim,
|
||||
v_head_dim=config.v_head_dim,
|
||||
q_lora_rank=(
|
||||
config.q_lora_rank if hasattr(config, "q_lora_rank") else None
|
||||
),
|
||||
kv_lora_rank=config.kv_lora_rank,
|
||||
rope_theta=rope_theta,
|
||||
rope_scaling=rope_scaling,
|
||||
max_position_embeddings=max_position_embeddings,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
layer_id=layer_id,
|
||||
)
|
||||
if (
|
||||
config.n_routed_experts is not None
|
||||
and layer_id >= config.first_k_dense_replace
|
||||
|
||||
@@ -80,6 +80,7 @@ class ServerArgs:
|
||||
disable_disk_cache: bool = False
|
||||
enable_torch_compile: bool = False
|
||||
enable_p2p_check: bool = False
|
||||
enable_mla: bool = False
|
||||
attention_reduce_in_fp32: bool = False
|
||||
efficient_weight_load: bool = False
|
||||
|
||||
@@ -393,6 +394,11 @@ class ServerArgs:
|
||||
action="store_true",
|
||||
help="Enable P2P check for GPU access, otherwise the p2p access is allowed by default.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--enable-mla",
|
||||
action="store_true",
|
||||
help="Enable Multi-head Latent Attention (MLA) for DeepSeek-V2",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--attention-reduce-in-fp32",
|
||||
action="store_true",
|
||||
|
||||
Reference in New Issue
Block a user