Support MLA for DeepSeek-V2 with Triton - step 1 (#905)
This commit is contained in:
0
benchmark/gsm8k/download_data.sh
Normal file → Executable file
0
benchmark/gsm8k/download_data.sh
Normal file → Executable file
@@ -57,6 +57,8 @@ def _fwd_kernel(
|
|||||||
stride_buf_vh,
|
stride_buf_vh,
|
||||||
stride_req_to_tokens_b,
|
stride_req_to_tokens_b,
|
||||||
BLOCK_DMODEL: tl.constexpr,
|
BLOCK_DMODEL: tl.constexpr,
|
||||||
|
BLOCK_DPE: tl.constexpr,
|
||||||
|
BLOCK_DV: tl.constexpr,
|
||||||
BLOCK_M: tl.constexpr,
|
BLOCK_M: tl.constexpr,
|
||||||
BLOCK_N: tl.constexpr,
|
BLOCK_N: tl.constexpr,
|
||||||
logit_cap: tl.constexpr,
|
logit_cap: tl.constexpr,
|
||||||
@@ -75,8 +77,10 @@ def _fwd_kernel(
|
|||||||
cur_batch_req_idx = tl.load(B_req_idx + cur_seq)
|
cur_batch_req_idx = tl.load(B_req_idx + cur_seq)
|
||||||
|
|
||||||
offs_d = tl.arange(0, BLOCK_DMODEL)
|
offs_d = tl.arange(0, BLOCK_DMODEL)
|
||||||
|
offs_dv = tl.arange(0, BLOCK_DV)
|
||||||
offs_m = tl.arange(0, BLOCK_M)
|
offs_m = tl.arange(0, BLOCK_M)
|
||||||
mask_m = (cur_block_m * BLOCK_M + offs_m) < cur_seq_len_extend
|
mask_m = (cur_block_m * BLOCK_M + offs_m) < cur_seq_len_extend
|
||||||
|
|
||||||
offs_q = (
|
offs_q = (
|
||||||
(cur_seq_extend_start_contiguous + cur_block_m * BLOCK_M + offs_m[:, None])
|
(cur_seq_extend_start_contiguous + cur_block_m * BLOCK_M + offs_m[:, None])
|
||||||
* stride_qbs
|
* stride_qbs
|
||||||
@@ -85,10 +89,20 @@ def _fwd_kernel(
|
|||||||
)
|
)
|
||||||
q = tl.load(Q_Extend + offs_q, mask=mask_m[:, None], other=0.0)
|
q = tl.load(Q_Extend + offs_q, mask=mask_m[:, None], other=0.0)
|
||||||
|
|
||||||
|
if BLOCK_DPE > 0:
|
||||||
|
offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE)
|
||||||
|
offs_qpe = (
|
||||||
|
(cur_seq_extend_start_contiguous + cur_block_m * BLOCK_M + offs_m[:, None])
|
||||||
|
* stride_qbs
|
||||||
|
+ cur_head * stride_qh
|
||||||
|
+ offs_dpe[None, :]
|
||||||
|
)
|
||||||
|
qpe = tl.load(Q_Extend + offs_qpe, mask=mask_m[:, None], other=0.0)
|
||||||
|
|
||||||
# stage1: compute scores with prefix
|
# stage1: compute scores with prefix
|
||||||
offs_n = tl.arange(0, BLOCK_N)
|
offs_n = tl.arange(0, BLOCK_N)
|
||||||
|
|
||||||
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
|
acc = tl.zeros([BLOCK_M, BLOCK_DV], dtype=tl.float32)
|
||||||
deno = tl.zeros([BLOCK_M], dtype=tl.float32)
|
deno = tl.zeros([BLOCK_M], dtype=tl.float32)
|
||||||
e_max = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
|
e_max = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
|
||||||
|
|
||||||
@@ -110,6 +124,18 @@ def _fwd_kernel(
|
|||||||
|
|
||||||
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
|
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
|
||||||
qk += tl.dot(q, k)
|
qk += tl.dot(q, k)
|
||||||
|
if BLOCK_DPE > 0:
|
||||||
|
offs_kpe = (
|
||||||
|
offs_kv_loc[None, :] * stride_buf_kbs
|
||||||
|
+ cur_kv_head * stride_buf_kh
|
||||||
|
+ offs_dpe[:, None]
|
||||||
|
)
|
||||||
|
kpe = tl.load(
|
||||||
|
K_Buffer + offs_kpe,
|
||||||
|
mask=mask_n[None, :],
|
||||||
|
other=0.0,
|
||||||
|
)
|
||||||
|
qk += tl.dot(qpe, kpe)
|
||||||
qk *= sm_scale
|
qk *= sm_scale
|
||||||
|
|
||||||
if logit_cap > 0:
|
if logit_cap > 0:
|
||||||
@@ -125,7 +151,7 @@ def _fwd_kernel(
|
|||||||
offs_buf_v = (
|
offs_buf_v = (
|
||||||
offs_kv_loc[:, None] * stride_buf_vbs
|
offs_kv_loc[:, None] * stride_buf_vbs
|
||||||
+ cur_kv_head * stride_buf_vh
|
+ cur_kv_head * stride_buf_vh
|
||||||
+ offs_d[None, :]
|
+ offs_dv[None, :]
|
||||||
)
|
)
|
||||||
v = tl.load(V_Buffer + offs_buf_v, mask=mask_n[:, None], other=0.0)
|
v = tl.load(V_Buffer + offs_buf_v, mask=mask_n[:, None], other=0.0)
|
||||||
p = p.to(v.dtype)
|
p = p.to(v.dtype)
|
||||||
@@ -150,6 +176,21 @@ def _fwd_kernel(
|
|||||||
|
|
||||||
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
|
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
|
||||||
qk += tl.dot(q, k)
|
qk += tl.dot(q, k)
|
||||||
|
|
||||||
|
if BLOCK_DPE > 0:
|
||||||
|
offs_kpe = (
|
||||||
|
(cur_seq_extend_start_contiguous + start_n + offs_n[None, :])
|
||||||
|
* stride_kbs
|
||||||
|
+ cur_kv_head * stride_kh
|
||||||
|
+ offs_dpe[:, None]
|
||||||
|
)
|
||||||
|
kpe = tl.load(
|
||||||
|
K_Extend + offs_kpe,
|
||||||
|
mask=mask_n[None, :],
|
||||||
|
other=0.0,
|
||||||
|
)
|
||||||
|
qk += tl.dot(qpe, kpe)
|
||||||
|
|
||||||
qk *= sm_scale
|
qk *= sm_scale
|
||||||
|
|
||||||
if logit_cap > 0:
|
if logit_cap > 0:
|
||||||
@@ -169,7 +210,7 @@ def _fwd_kernel(
|
|||||||
offs_v = (
|
offs_v = (
|
||||||
(cur_seq_extend_start_contiguous + start_n + offs_n[:, None]) * stride_vbs
|
(cur_seq_extend_start_contiguous + start_n + offs_n[:, None]) * stride_vbs
|
||||||
+ cur_kv_head * stride_vh
|
+ cur_kv_head * stride_vh
|
||||||
+ offs_d[None, :]
|
+ offs_dv[None, :]
|
||||||
)
|
)
|
||||||
v = tl.load(V_Extend + offs_v, mask=mask_n[:, None], other=0.0)
|
v = tl.load(V_Extend + offs_v, mask=mask_n[:, None], other=0.0)
|
||||||
p = p.to(v.dtype)
|
p = p.to(v.dtype)
|
||||||
@@ -181,7 +222,7 @@ def _fwd_kernel(
|
|||||||
(cur_seq_extend_start_contiguous + cur_block_m * BLOCK_M + offs_m[:, None])
|
(cur_seq_extend_start_contiguous + cur_block_m * BLOCK_M + offs_m[:, None])
|
||||||
* stride_obs
|
* stride_obs
|
||||||
+ cur_head * stride_oh
|
+ cur_head * stride_oh
|
||||||
+ offs_d[None, :]
|
+ offs_dv[None, :]
|
||||||
)
|
)
|
||||||
tl.store(O_Extend + offs_o, acc / deno[:, None], mask=mask_m[:, None])
|
tl.store(O_Extend + offs_o, acc / deno[:, None], mask=mask_m[:, None])
|
||||||
|
|
||||||
@@ -217,8 +258,17 @@ def extend_attention_fwd(
|
|||||||
o_extend.shape[-1],
|
o_extend.shape[-1],
|
||||||
)
|
)
|
||||||
|
|
||||||
assert Lq == Lk and Lk == Lv and Lv == Lo
|
assert Lq == Lk and Lv == Lo
|
||||||
assert Lq in {16, 32, 64, 128, 256}
|
assert Lq in {16, 32, 64, 128, 256, 576}
|
||||||
|
assert Lv in {16, 32, 64, 128, 256, 512}
|
||||||
|
|
||||||
|
if Lq == 576:
|
||||||
|
BLOCK_DMODEL = 512
|
||||||
|
BLOCK_DPE = 64
|
||||||
|
else:
|
||||||
|
BLOCK_DMODEL = Lq
|
||||||
|
BLOCK_DPE = 0
|
||||||
|
BLOCK_DV = Lv
|
||||||
|
|
||||||
if CUDA_CAPABILITY[0] >= 8:
|
if CUDA_CAPABILITY[0] >= 8:
|
||||||
BLOCK_M, BLOCK_N = (128, 128) if Lq <= 128 else (64, 64)
|
BLOCK_M, BLOCK_N = (128, 128) if Lq <= 128 else (64, 64)
|
||||||
@@ -260,7 +310,9 @@ def extend_attention_fwd(
|
|||||||
v_buffer.stride(0),
|
v_buffer.stride(0),
|
||||||
v_buffer.stride(1),
|
v_buffer.stride(1),
|
||||||
req_to_tokens.stride(0),
|
req_to_tokens.stride(0),
|
||||||
BLOCK_DMODEL=Lq,
|
BLOCK_DMODEL=BLOCK_DMODEL,
|
||||||
|
BLOCK_DPE=BLOCK_DPE,
|
||||||
|
BLOCK_DV=BLOCK_DV,
|
||||||
BLOCK_M=BLOCK_M,
|
BLOCK_M=BLOCK_M,
|
||||||
BLOCK_N=BLOCK_N,
|
BLOCK_N=BLOCK_N,
|
||||||
num_warps=num_warps,
|
num_warps=num_warps,
|
||||||
|
|||||||
@@ -38,16 +38,22 @@ class RadixAttention(nn.Module):
|
|||||||
num_kv_heads: int,
|
num_kv_heads: int,
|
||||||
layer_id: int,
|
layer_id: int,
|
||||||
logit_cap: int = -1,
|
logit_cap: int = -1,
|
||||||
|
v_head_dim: int = -1,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.tp_q_head_num = num_heads
|
self.tp_q_head_num = num_heads
|
||||||
self.tp_k_head_num = num_kv_heads
|
self.tp_k_head_num = num_kv_heads
|
||||||
self.tp_v_head_num = num_kv_heads
|
self.tp_v_head_num = num_kv_heads
|
||||||
self.head_dim = head_dim
|
self.head_dim = head_dim
|
||||||
|
self.qk_head_dim = head_dim
|
||||||
|
self.v_head_dim = v_head_dim if v_head_dim != -1 else head_dim
|
||||||
self.scaling = scaling
|
self.scaling = scaling
|
||||||
self.layer_id = layer_id
|
self.layer_id = layer_id
|
||||||
|
|
||||||
if not global_server_args_dict.get("disable_flashinfer", False):
|
if (
|
||||||
|
not global_server_args_dict.get("disable_flashinfer", False)
|
||||||
|
and self.qk_head_dim == self.v_head_dim
|
||||||
|
):
|
||||||
self.extend_forward = self.extend_forward_flashinfer
|
self.extend_forward = self.extend_forward_flashinfer
|
||||||
self.decode_forward = self.decode_forward_flashinfer
|
self.decode_forward = self.decode_forward_flashinfer
|
||||||
else:
|
else:
|
||||||
@@ -57,13 +63,17 @@ class RadixAttention(nn.Module):
|
|||||||
self.logit_cap = logit_cap if logit_cap is not None and logit_cap > 0 else 0
|
self.logit_cap = logit_cap if logit_cap is not None and logit_cap > 0 else 0
|
||||||
|
|
||||||
def extend_forward_triton(self, q, k, v, input_metadata: InputMetadata):
|
def extend_forward_triton(self, q, k, v, input_metadata: InputMetadata):
|
||||||
o = torch.empty_like(q)
|
if self.qk_head_dim != self.v_head_dim:
|
||||||
|
o = q.new_empty((q.shape[0], self.tp_q_head_num * self.v_head_dim))
|
||||||
|
else:
|
||||||
|
o = torch.empty_like(q)
|
||||||
|
|
||||||
self.store_kv_cache(k, v, input_metadata)
|
self.store_kv_cache(k, v, input_metadata)
|
||||||
extend_attention_fwd(
|
extend_attention_fwd(
|
||||||
q.view(-1, self.tp_q_head_num, self.head_dim),
|
q.view(-1, self.tp_q_head_num, self.qk_head_dim),
|
||||||
k.contiguous(),
|
k.contiguous(),
|
||||||
v.contiguous(),
|
v.contiguous(),
|
||||||
o.view(-1, self.tp_q_head_num, self.head_dim),
|
o.view(-1, self.tp_q_head_num, self.v_head_dim),
|
||||||
input_metadata.token_to_kv_pool.get_key_buffer(self.layer_id),
|
input_metadata.token_to_kv_pool.get_key_buffer(self.layer_id),
|
||||||
input_metadata.token_to_kv_pool.get_value_buffer(self.layer_id),
|
input_metadata.token_to_kv_pool.get_value_buffer(self.layer_id),
|
||||||
input_metadata.req_to_token_pool.req_to_token,
|
input_metadata.req_to_token_pool.req_to_token,
|
||||||
@@ -82,14 +92,17 @@ class RadixAttention(nn.Module):
|
|||||||
return o
|
return o
|
||||||
|
|
||||||
def decode_forward_triton(self, q, k, v, input_metadata: InputMetadata):
|
def decode_forward_triton(self, q, k, v, input_metadata: InputMetadata):
|
||||||
o = torch.empty_like(q)
|
if self.qk_head_dim != self.v_head_dim:
|
||||||
|
o = q.new_empty((q.shape[0], self.tp_q_head_num * self.v_head_dim))
|
||||||
|
else:
|
||||||
|
o = torch.empty_like(q)
|
||||||
self.store_kv_cache(k, v, input_metadata)
|
self.store_kv_cache(k, v, input_metadata)
|
||||||
|
|
||||||
token_attention_fwd(
|
token_attention_fwd(
|
||||||
q.view(-1, self.tp_q_head_num, self.head_dim),
|
q.view(-1, self.tp_q_head_num, self.qk_head_dim),
|
||||||
input_metadata.token_to_kv_pool.get_key_buffer(self.layer_id),
|
input_metadata.token_to_kv_pool.get_key_buffer(self.layer_id),
|
||||||
input_metadata.token_to_kv_pool.get_value_buffer(self.layer_id),
|
input_metadata.token_to_kv_pool.get_value_buffer(self.layer_id),
|
||||||
o.view(-1, self.tp_q_head_num, self.head_dim),
|
o.view(-1, self.tp_q_head_num, self.v_head_dim),
|
||||||
input_metadata.req_to_token_pool.req_to_token,
|
input_metadata.req_to_token_pool.req_to_token,
|
||||||
input_metadata.req_pool_indices,
|
input_metadata.req_pool_indices,
|
||||||
input_metadata.triton_start_loc,
|
input_metadata.triton_start_loc,
|
||||||
@@ -160,8 +173,8 @@ class RadixAttention(nn.Module):
|
|||||||
return o.view(-1, self.tp_q_head_num * self.head_dim)
|
return o.view(-1, self.tp_q_head_num * self.head_dim)
|
||||||
|
|
||||||
def forward(self, q, k, v, input_metadata: InputMetadata):
|
def forward(self, q, k, v, input_metadata: InputMetadata):
|
||||||
k = k.view(-1, self.tp_k_head_num, self.head_dim)
|
k = k.view(-1, self.tp_k_head_num, self.qk_head_dim)
|
||||||
v = v.view(-1, self.tp_v_head_num, self.head_dim)
|
v = v.view(-1, self.tp_v_head_num, self.v_head_dim)
|
||||||
|
|
||||||
if input_metadata.forward_mode == ForwardMode.EXTEND:
|
if input_metadata.forward_mode == ForwardMode.EXTEND:
|
||||||
return self.extend_forward(q, k, v, input_metadata)
|
return self.extend_forward(q, k, v, input_metadata)
|
||||||
|
|||||||
@@ -54,6 +54,7 @@ def _fwd_kernel_stage1(
|
|||||||
att_stride_h,
|
att_stride_h,
|
||||||
kv_group_num: tl.constexpr,
|
kv_group_num: tl.constexpr,
|
||||||
BLOCK_DMODEL: tl.constexpr,
|
BLOCK_DMODEL: tl.constexpr,
|
||||||
|
BLOCK_DPE: tl.constexpr,
|
||||||
BLOCK_N: tl.constexpr,
|
BLOCK_N: tl.constexpr,
|
||||||
logit_cap: tl.constexpr,
|
logit_cap: tl.constexpr,
|
||||||
):
|
):
|
||||||
@@ -73,6 +74,10 @@ def _fwd_kernel_stage1(
|
|||||||
|
|
||||||
off_q = cur_batch * stride_qbs + cur_head * stride_qh + offs_d
|
off_q = cur_batch * stride_qbs + cur_head * stride_qh + offs_d
|
||||||
|
|
||||||
|
if BLOCK_DPE > 0:
|
||||||
|
offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE)
|
||||||
|
off_qpe = cur_batch * stride_qbs + cur_head * stride_qh + offs_dpe
|
||||||
|
|
||||||
offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N)
|
offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N)
|
||||||
|
|
||||||
block_stard_index = start_n * BLOCK_N
|
block_stard_index = start_n * BLOCK_N
|
||||||
@@ -97,6 +102,19 @@ def _fwd_kernel_stage1(
|
|||||||
other=0.0,
|
other=0.0,
|
||||||
).to(REDUCE_TRITON_TYPE)
|
).to(REDUCE_TRITON_TYPE)
|
||||||
att_value = tl.sum(q[None, :] * k, 1)
|
att_value = tl.sum(q[None, :] * k, 1)
|
||||||
|
if BLOCK_DPE > 0:
|
||||||
|
qpe = tl.load(Q + off_qpe + start_mark).to(REDUCE_TRITON_TYPE)
|
||||||
|
offs_buf_kpe = (
|
||||||
|
k_loc[:, None] * stride_buf_kbs
|
||||||
|
+ cur_kv_head * stride_buf_kh
|
||||||
|
+ offs_dpe[None, :]
|
||||||
|
)
|
||||||
|
kpe = tl.load(
|
||||||
|
K_Buffer + offs_buf_kpe,
|
||||||
|
mask=offs_n_new[:, None] < cur_batch_end_index,
|
||||||
|
other=0.0,
|
||||||
|
).to(REDUCE_TRITON_TYPE)
|
||||||
|
att_value += tl.sum(qpe[None, :] * kpe, 1)
|
||||||
att_value *= sm_scale
|
att_value *= sm_scale
|
||||||
|
|
||||||
if logit_cap > 0:
|
if logit_cap > 0:
|
||||||
@@ -192,7 +210,14 @@ def _token_att_m_fwd(
|
|||||||
# shape constraints
|
# shape constraints
|
||||||
Lq, Lk = q.shape[-1], k_buffer.shape[-1]
|
Lq, Lk = q.shape[-1], k_buffer.shape[-1]
|
||||||
assert Lq == Lk
|
assert Lq == Lk
|
||||||
assert Lk in {16, 32, 64, 128, 256}
|
assert Lk in {16, 32, 64, 128, 256, 576}
|
||||||
|
|
||||||
|
if Lk == 576:
|
||||||
|
BLOCK_DMODEL = 512
|
||||||
|
BLOCK_DPE = 64
|
||||||
|
else:
|
||||||
|
BLOCK_DMODEL = Lk
|
||||||
|
BLOCK_DPE = 0
|
||||||
|
|
||||||
batch, head_num = B_req_idx.shape[0], q.shape[1]
|
batch, head_num = B_req_idx.shape[0], q.shape[1]
|
||||||
|
|
||||||
@@ -220,7 +245,8 @@ def _token_att_m_fwd(
|
|||||||
k_buffer.stride(1),
|
k_buffer.stride(1),
|
||||||
att_out.stride(0),
|
att_out.stride(0),
|
||||||
kv_group_num=kv_group_num,
|
kv_group_num=kv_group_num,
|
||||||
BLOCK_DMODEL=Lk,
|
BLOCK_DMODEL=BLOCK_DMODEL,
|
||||||
|
BLOCK_DPE=BLOCK_DPE,
|
||||||
BLOCK_N=BLOCK,
|
BLOCK_N=BLOCK,
|
||||||
logit_cap=logit_cap,
|
logit_cap=logit_cap,
|
||||||
num_warps=num_warps,
|
num_warps=num_warps,
|
||||||
|
|||||||
@@ -29,7 +29,7 @@ from sglang.global_config import global_config
|
|||||||
from sglang.srt.constrained import RegexGuide
|
from sglang.srt.constrained import RegexGuide
|
||||||
from sglang.srt.constrained.jump_forward import JumpForwardMap
|
from sglang.srt.constrained.jump_forward import JumpForwardMap
|
||||||
from sglang.srt.mem_cache.chunk_cache import ChunkCache
|
from sglang.srt.mem_cache.chunk_cache import ChunkCache
|
||||||
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPool
|
from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
|
||||||
from sglang.srt.mem_cache.radix_cache import RadixCache
|
from sglang.srt.mem_cache.radix_cache import RadixCache
|
||||||
|
|
||||||
INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
|
INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
|
||||||
@@ -39,6 +39,7 @@ global_server_args_dict = {
|
|||||||
"disable_flashinfer": False,
|
"disable_flashinfer": False,
|
||||||
"disable_flashinfer_sampling": False,
|
"disable_flashinfer_sampling": False,
|
||||||
"attention_reduce_in_fp32": False,
|
"attention_reduce_in_fp32": False,
|
||||||
|
"enable_mla": False,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -289,7 +290,7 @@ class Batch:
|
|||||||
# Request, memory pool, and cache
|
# Request, memory pool, and cache
|
||||||
reqs: List[Req]
|
reqs: List[Req]
|
||||||
req_to_token_pool: ReqToTokenPool
|
req_to_token_pool: ReqToTokenPool
|
||||||
token_to_kv_pool: TokenToKVPool
|
token_to_kv_pool: BaseTokenToKVPool
|
||||||
tree_cache: RadixCache
|
tree_cache: RadixCache
|
||||||
|
|
||||||
# Batched arguments to model runner
|
# Batched arguments to model runner
|
||||||
@@ -780,7 +781,7 @@ class InputMetadata:
|
|||||||
seq_lens: torch.Tensor
|
seq_lens: torch.Tensor
|
||||||
positions: torch.Tensor
|
positions: torch.Tensor
|
||||||
req_to_token_pool: ReqToTokenPool
|
req_to_token_pool: ReqToTokenPool
|
||||||
token_to_kv_pool: TokenToKVPool
|
token_to_kv_pool: BaseTokenToKVPool
|
||||||
|
|
||||||
# For extend
|
# For extend
|
||||||
extend_seq_lens: torch.Tensor
|
extend_seq_lens: torch.Tensor
|
||||||
|
|||||||
@@ -57,32 +57,18 @@ class ReqToTokenPool:
|
|||||||
self.can_use_mem_size = len(self.mem_state)
|
self.can_use_mem_size = len(self.mem_state)
|
||||||
|
|
||||||
|
|
||||||
class TokenToKVPool:
|
class BaseTokenToKVPool:
|
||||||
"""A memory pool that maps a token to its kv cache locations"""
|
"""A memory pool that maps a token to its kv cache locations"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
size: int,
|
size: int,
|
||||||
dtype: torch.dtype,
|
|
||||||
head_num: int,
|
|
||||||
head_dim: int,
|
|
||||||
layer_num: int,
|
|
||||||
):
|
):
|
||||||
self.size = size
|
self.size = size
|
||||||
|
|
||||||
# We also add one slot. This slot is used for writing dummy output from padded tokens.
|
# We also add one slot. This slot is used for writing dummy output from padded tokens.
|
||||||
self.mem_state = torch.ones((self.size + 1,), dtype=torch.bool, device="cuda")
|
self.mem_state = torch.ones((self.size + 1,), dtype=torch.bool, device="cuda")
|
||||||
|
|
||||||
# [size, head_num, head_dim] for each layer
|
|
||||||
self.k_buffer = [
|
|
||||||
torch.empty((size + 1, head_num, head_dim), dtype=dtype, device="cuda")
|
|
||||||
for _ in range(layer_num)
|
|
||||||
]
|
|
||||||
self.v_buffer = [
|
|
||||||
torch.empty((size + 1, head_num, head_dim), dtype=dtype, device="cuda")
|
|
||||||
for _ in range(layer_num)
|
|
||||||
]
|
|
||||||
|
|
||||||
# Prefetch buffer
|
# Prefetch buffer
|
||||||
self.prefetch_buffer = torch.empty(0, device="cuda", dtype=torch.int32)
|
self.prefetch_buffer = torch.empty(0, device="cuda", dtype=torch.int32)
|
||||||
self.prefetch_chunk_size = 512
|
self.prefetch_chunk_size = 512
|
||||||
@@ -90,15 +76,6 @@ class TokenToKVPool:
|
|||||||
self.can_use_mem_size = self.size
|
self.can_use_mem_size = self.size
|
||||||
self.clear()
|
self.clear()
|
||||||
|
|
||||||
def get_key_buffer(self, layer_id: int):
|
|
||||||
return self.k_buffer[layer_id]
|
|
||||||
|
|
||||||
def get_value_buffer(self, layer_id: int):
|
|
||||||
return self.v_buffer[layer_id]
|
|
||||||
|
|
||||||
def get_kv_buffer(self, layer_id: int):
|
|
||||||
return self.k_buffer[layer_id], self.v_buffer[layer_id]
|
|
||||||
|
|
||||||
def available_size(self):
|
def available_size(self):
|
||||||
return self.can_use_mem_size + len(self.prefetch_buffer)
|
return self.can_use_mem_size + len(self.prefetch_buffer)
|
||||||
|
|
||||||
@@ -139,3 +116,67 @@ class TokenToKVPool:
|
|||||||
|
|
||||||
# We also add one slot. This slot is used for writing dummy output from padded tokens.
|
# We also add one slot. This slot is used for writing dummy output from padded tokens.
|
||||||
self.mem_state[0] = False
|
self.mem_state[0] = False
|
||||||
|
|
||||||
|
|
||||||
|
class MHATokenToKVPool(BaseTokenToKVPool):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
size: int,
|
||||||
|
dtype: torch.dtype,
|
||||||
|
head_num: int,
|
||||||
|
head_dim: int,
|
||||||
|
layer_num: int,
|
||||||
|
):
|
||||||
|
super().__init__(size)
|
||||||
|
|
||||||
|
# [size, head_num, head_dim] for each layer
|
||||||
|
self.k_buffer = [
|
||||||
|
torch.empty((size + 1, head_num, head_dim), dtype=dtype, device="cuda")
|
||||||
|
for _ in range(layer_num)
|
||||||
|
]
|
||||||
|
self.v_buffer = [
|
||||||
|
torch.empty((size + 1, head_num, head_dim), dtype=dtype, device="cuda")
|
||||||
|
for _ in range(layer_num)
|
||||||
|
]
|
||||||
|
|
||||||
|
def get_key_buffer(self, layer_id: int):
|
||||||
|
return self.k_buffer[layer_id]
|
||||||
|
|
||||||
|
def get_value_buffer(self, layer_id: int):
|
||||||
|
return self.v_buffer[layer_id]
|
||||||
|
|
||||||
|
def get_kv_buffer(self, layer_id: int):
|
||||||
|
return self.k_buffer[layer_id], self.v_buffer[layer_id]
|
||||||
|
|
||||||
|
|
||||||
|
class MLATokenToKVPool(BaseTokenToKVPool):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
size: int,
|
||||||
|
dtype: torch.dtype,
|
||||||
|
kv_lora_rank: int,
|
||||||
|
qk_rope_head_dim: int,
|
||||||
|
layer_num: int,
|
||||||
|
):
|
||||||
|
super().__init__(size)
|
||||||
|
|
||||||
|
self.kv_lora_rank = kv_lora_rank
|
||||||
|
self.kv_buffer = [
|
||||||
|
torch.empty(
|
||||||
|
(size + 1, 1, kv_lora_rank + qk_rope_head_dim),
|
||||||
|
dtype=dtype,
|
||||||
|
device="cuda",
|
||||||
|
)
|
||||||
|
for _ in range(layer_num)
|
||||||
|
]
|
||||||
|
|
||||||
|
def get_key_buffer(self, layer_id: int):
|
||||||
|
return self.kv_buffer[layer_id]
|
||||||
|
|
||||||
|
def get_value_buffer(self, layer_id: int):
|
||||||
|
return self.kv_buffer[layer_id][..., : self.kv_lora_rank]
|
||||||
|
|
||||||
|
def get_kv_buffer(self, layer_id: int):
|
||||||
|
return self.get_key_buffer(layer_id), self.get_value_buffer(layer_id)
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
|
|||||||
limitations under the License.
|
limitations under the License.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
from enum import IntEnum, auto
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from transformers import PretrainedConfig
|
from transformers import PretrainedConfig
|
||||||
@@ -20,6 +21,11 @@ from transformers import PretrainedConfig
|
|||||||
from sglang.srt.hf_transformers_utils import get_config, get_context_length
|
from sglang.srt.hf_transformers_utils import get_config, get_context_length
|
||||||
|
|
||||||
|
|
||||||
|
class AttentionArch(IntEnum):
|
||||||
|
MLA = auto()
|
||||||
|
MHA = auto()
|
||||||
|
|
||||||
|
|
||||||
class ModelConfig:
|
class ModelConfig:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -55,6 +61,11 @@ class ModelConfig:
|
|||||||
# FIXME: temporary special judge for deepseek v2 MLA architecture
|
# FIXME: temporary special judge for deepseek v2 MLA architecture
|
||||||
if "DeepseekV2ForCausalLM" in self.hf_config.architectures:
|
if "DeepseekV2ForCausalLM" in self.hf_config.architectures:
|
||||||
self.head_dim = 256
|
self.head_dim = 256
|
||||||
|
self.attention_arch = AttentionArch.MLA
|
||||||
|
self.kv_lora_rank = self.hf_config.kv_lora_rank
|
||||||
|
self.qk_rope_head_dim = self.hf_config.qk_rope_head_dim
|
||||||
|
else:
|
||||||
|
self.attention_arch = AttentionArch.MHA
|
||||||
|
|
||||||
self.num_attention_heads = self.hf_config.num_attention_heads
|
self.num_attention_heads = self.hf_config.num_attention_heads
|
||||||
self.num_key_value_heads = getattr(self.hf_config, "num_key_value_heads", None)
|
self.num_key_value_heads = getattr(self.hf_config, "num_key_value_heads", None)
|
||||||
|
|||||||
@@ -47,7 +47,12 @@ from sglang.srt.managers.schedule_batch import (
|
|||||||
InputMetadata,
|
InputMetadata,
|
||||||
global_server_args_dict,
|
global_server_args_dict,
|
||||||
)
|
)
|
||||||
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPool
|
from sglang.srt.mem_cache.memory_pool import (
|
||||||
|
MHATokenToKVPool,
|
||||||
|
MLATokenToKVPool,
|
||||||
|
ReqToTokenPool,
|
||||||
|
)
|
||||||
|
from sglang.srt.model_config import AttentionArch
|
||||||
from sglang.srt.server_args import ServerArgs
|
from sglang.srt.server_args import ServerArgs
|
||||||
from sglang.srt.utils import (
|
from sglang.srt.utils import (
|
||||||
get_available_gpu_memory,
|
get_available_gpu_memory,
|
||||||
@@ -86,6 +91,7 @@ class ModelRunner:
|
|||||||
"disable_flashinfer": server_args.disable_flashinfer,
|
"disable_flashinfer": server_args.disable_flashinfer,
|
||||||
"disable_flashinfer_sampling": server_args.disable_flashinfer_sampling,
|
"disable_flashinfer_sampling": server_args.disable_flashinfer_sampling,
|
||||||
"attention_reduce_in_fp32": server_args.attention_reduce_in_fp32,
|
"attention_reduce_in_fp32": server_args.attention_reduce_in_fp32,
|
||||||
|
"enable_mla": server_args.enable_mla,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -193,15 +199,23 @@ class ModelRunner:
|
|||||||
available_gpu_memory = get_available_gpu_memory(
|
available_gpu_memory = get_available_gpu_memory(
|
||||||
self.gpu_id, distributed=self.tp_size > 1
|
self.gpu_id, distributed=self.tp_size > 1
|
||||||
)
|
)
|
||||||
head_dim = self.model_config.head_dim
|
if (
|
||||||
head_num = self.model_config.get_num_kv_heads(self.tp_size)
|
self.model_config.attention_arch == AttentionArch.MLA
|
||||||
cell_size = (
|
and self.server_args.enable_mla
|
||||||
head_num
|
):
|
||||||
* head_dim
|
cell_size = (
|
||||||
* self.model_config.num_hidden_layers
|
(self.model_config.kv_lora_rank + self.model_config.qk_rope_head_dim)
|
||||||
* 2
|
* self.model_config.num_hidden_layers
|
||||||
* torch._utils._element_size(self.dtype)
|
* torch._utils._element_size(self.dtype)
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
cell_size = (
|
||||||
|
self.model_config.get_num_kv_heads(self.tp_size)
|
||||||
|
* self.model_config.head_dim
|
||||||
|
* self.model_config.num_hidden_layers
|
||||||
|
* 2
|
||||||
|
* torch._utils._element_size(self.dtype)
|
||||||
|
)
|
||||||
rest_memory = available_gpu_memory - total_gpu_memory * (
|
rest_memory = available_gpu_memory - total_gpu_memory * (
|
||||||
1 - self.mem_fraction_static
|
1 - self.mem_fraction_static
|
||||||
)
|
)
|
||||||
@@ -241,13 +255,28 @@ class ModelRunner:
|
|||||||
max_num_reqs,
|
max_num_reqs,
|
||||||
self.model_config.context_len + 8,
|
self.model_config.context_len + 8,
|
||||||
)
|
)
|
||||||
self.token_to_kv_pool = TokenToKVPool(
|
if (
|
||||||
self.max_total_num_tokens,
|
self.model_config.attention_arch == AttentionArch.MLA
|
||||||
dtype=self.dtype,
|
and self.server_args.enable_mla
|
||||||
head_num=self.model_config.get_num_kv_heads(self.tp_size),
|
):
|
||||||
head_dim=self.model_config.head_dim,
|
self.token_to_kv_pool = MLATokenToKVPool(
|
||||||
layer_num=self.model_config.num_hidden_layers,
|
self.max_total_num_tokens,
|
||||||
)
|
dtype=self.dtype,
|
||||||
|
kv_lora_rank=self.model_config.kv_lora_rank,
|
||||||
|
qk_rope_head_dim=self.model_config.qk_rope_head_dim,
|
||||||
|
layer_num=self.model_config.num_hidden_layers,
|
||||||
|
)
|
||||||
|
logger.info("using MLA Triton implementaion, flashinfer is disabled")
|
||||||
|
# FIXME: temporarily only Triton MLA is supported
|
||||||
|
self.server_args.disable_flashinfer = True
|
||||||
|
else:
|
||||||
|
self.token_to_kv_pool = MHATokenToKVPool(
|
||||||
|
self.max_total_num_tokens,
|
||||||
|
dtype=self.dtype,
|
||||||
|
head_num=self.model_config.get_num_kv_heads(self.tp_size),
|
||||||
|
head_dim=self.model_config.head_dim,
|
||||||
|
layer_num=self.model_config.num_hidden_layers,
|
||||||
|
)
|
||||||
logger.info(
|
logger.info(
|
||||||
f"[gpu={self.gpu_id}] Memory pool end. "
|
f"[gpu={self.gpu_id}] Memory pool end. "
|
||||||
f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
|
f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
|
||||||
|
|||||||
@@ -45,6 +45,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
|||||||
|
|
||||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||||
from sglang.srt.layers.radix_attention import RadixAttention
|
from sglang.srt.layers.radix_attention import RadixAttention
|
||||||
|
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||||
from sglang.srt.model_executor.model_runner import InputMetadata
|
from sglang.srt.model_executor.model_runner import InputMetadata
|
||||||
|
|
||||||
|
|
||||||
@@ -312,6 +313,165 @@ class DeepseekV2Attention(nn.Module):
|
|||||||
return output
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
class DeepseekV2AttentionMLA(nn.Module):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
config: PretrainedConfig,
|
||||||
|
hidden_size: int,
|
||||||
|
num_heads: int,
|
||||||
|
qk_nope_head_dim: int,
|
||||||
|
qk_rope_head_dim: int,
|
||||||
|
v_head_dim: int,
|
||||||
|
q_lora_rank: int,
|
||||||
|
kv_lora_rank: int,
|
||||||
|
rope_theta: float = 10000,
|
||||||
|
rope_scaling: Optional[Dict[str, Any]] = None,
|
||||||
|
max_position_embeddings: int = 8192,
|
||||||
|
cache_config: Optional[CacheConfig] = None,
|
||||||
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
layer_id=None,
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.layer_id = layer_id
|
||||||
|
self.hidden_size = hidden_size
|
||||||
|
self.qk_nope_head_dim = qk_nope_head_dim
|
||||||
|
self.qk_rope_head_dim = qk_rope_head_dim
|
||||||
|
self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
|
||||||
|
self.v_head_dim = v_head_dim
|
||||||
|
self.q_lora_rank = q_lora_rank
|
||||||
|
self.kv_lora_rank = kv_lora_rank
|
||||||
|
self.num_heads = num_heads
|
||||||
|
tp_size = get_tensor_model_parallel_world_size()
|
||||||
|
assert num_heads % tp_size == 0
|
||||||
|
self.num_local_heads = num_heads // tp_size
|
||||||
|
self.scaling = self.qk_head_dim**-0.5
|
||||||
|
self.rope_theta = rope_theta
|
||||||
|
self.max_position_embeddings = max_position_embeddings
|
||||||
|
|
||||||
|
if self.q_lora_rank is not None:
|
||||||
|
self.q_a_proj = ReplicatedLinear(
|
||||||
|
self.hidden_size,
|
||||||
|
self.q_lora_rank,
|
||||||
|
bias=False,
|
||||||
|
quant_config=quant_config,
|
||||||
|
)
|
||||||
|
self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps)
|
||||||
|
self.q_b_proj = ColumnParallelLinear(
|
||||||
|
q_lora_rank,
|
||||||
|
self.num_heads * self.qk_head_dim,
|
||||||
|
bias=False,
|
||||||
|
quant_config=quant_config,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.q_proj = ColumnParallelLinear(
|
||||||
|
self.hidden_size,
|
||||||
|
self.num_heads * self.qk_head_dim,
|
||||||
|
bias=False,
|
||||||
|
quant_config=quant_config,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.kv_a_proj_with_mqa = ReplicatedLinear(
|
||||||
|
self.hidden_size,
|
||||||
|
self.kv_lora_rank + self.qk_rope_head_dim,
|
||||||
|
bias=False,
|
||||||
|
quant_config=quant_config,
|
||||||
|
)
|
||||||
|
self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps)
|
||||||
|
self.kv_b_proj = ColumnParallelLinear(
|
||||||
|
self.kv_lora_rank,
|
||||||
|
self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
|
||||||
|
bias=False,
|
||||||
|
quant_config=quant_config,
|
||||||
|
)
|
||||||
|
# O projection.
|
||||||
|
self.o_proj = RowParallelLinear(
|
||||||
|
self.num_heads * self.v_head_dim,
|
||||||
|
self.hidden_size,
|
||||||
|
bias=False,
|
||||||
|
quant_config=quant_config,
|
||||||
|
)
|
||||||
|
rope_scaling["type"] = "deepseek_yarn"
|
||||||
|
self.rotary_emb = get_rope(
|
||||||
|
qk_rope_head_dim,
|
||||||
|
rotary_dim=qk_rope_head_dim,
|
||||||
|
max_position=max_position_embeddings,
|
||||||
|
base=rope_theta,
|
||||||
|
rope_scaling=rope_scaling,
|
||||||
|
is_neox_style=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
if rope_scaling:
|
||||||
|
mscale_all_dim = rope_scaling.get("mscale_all_dim", False)
|
||||||
|
scaling_factor = rope_scaling["factor"]
|
||||||
|
mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
|
||||||
|
self.scaling = self.scaling * mscale * mscale
|
||||||
|
|
||||||
|
self.attn = RadixAttention(
|
||||||
|
self.num_local_heads,
|
||||||
|
self.kv_lora_rank + self.qk_rope_head_dim,
|
||||||
|
self.scaling,
|
||||||
|
num_kv_heads=1,
|
||||||
|
layer_id=layer_id,
|
||||||
|
v_head_dim=self.kv_lora_rank,
|
||||||
|
)
|
||||||
|
|
||||||
|
kv_b_proj = self.kv_b_proj
|
||||||
|
w_kc, w_vc = kv_b_proj.weight.unflatten(
|
||||||
|
0, (-1, qk_nope_head_dim + v_head_dim)
|
||||||
|
).split([qk_nope_head_dim, v_head_dim], dim=1)
|
||||||
|
self.w_kc = w_kc
|
||||||
|
self.w_vc = w_vc
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
positions: torch.Tensor,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
input_metadata: InputMetadata,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
q_len = hidden_states.shape[0]
|
||||||
|
q_input = hidden_states.new_empty(
|
||||||
|
q_len, self.num_local_heads, self.kv_lora_rank + self.qk_rope_head_dim
|
||||||
|
)
|
||||||
|
if self.q_lora_rank is not None:
|
||||||
|
q = self.q_a_proj(hidden_states)[0]
|
||||||
|
q = self.q_a_layernorm(q)
|
||||||
|
q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
|
||||||
|
else:
|
||||||
|
q = self.q_proj(hidden_states)[0].view(
|
||||||
|
-1, self.num_local_heads, self.qk_head_dim
|
||||||
|
)
|
||||||
|
q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
|
||||||
|
q_nope_out = q_input[..., : self.kv_lora_rank]
|
||||||
|
torch.bmm(q_nope.transpose(0, 1), self.w_kc, out=q_nope_out.transpose(0, 1))
|
||||||
|
|
||||||
|
k_input = self.kv_a_proj_with_mqa(hidden_states)[0].unsqueeze(1)
|
||||||
|
k_pe = k_input[..., self.kv_lora_rank :]
|
||||||
|
v_input = k_input[..., : self.kv_lora_rank]
|
||||||
|
v_input = self.kv_a_layernorm(v_input.contiguous())
|
||||||
|
k_input[..., : self.kv_lora_rank] = v_input
|
||||||
|
|
||||||
|
q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
|
||||||
|
q_input[..., self.kv_lora_rank :] = q_pe
|
||||||
|
k_input[..., self.kv_lora_rank :] = k_pe
|
||||||
|
|
||||||
|
attn_output = self.attn(q_input, k_input, v_input, input_metadata)
|
||||||
|
attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)
|
||||||
|
attn_bmm_output = attn_output.new_empty(
|
||||||
|
q_len, self.num_local_heads, self.v_head_dim
|
||||||
|
)
|
||||||
|
torch.bmm(
|
||||||
|
attn_output.transpose(0, 1),
|
||||||
|
self.w_vc.transpose(1, 2).contiguous(),
|
||||||
|
out=attn_bmm_output.transpose(0, 1),
|
||||||
|
)
|
||||||
|
|
||||||
|
attn_output = attn_bmm_output.flatten(1, 2)
|
||||||
|
output, _ = self.o_proj(attn_output)
|
||||||
|
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
class DeepseekV2DecoderLayer(nn.Module):
|
class DeepseekV2DecoderLayer(nn.Module):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -326,22 +486,44 @@ class DeepseekV2DecoderLayer(nn.Module):
|
|||||||
rope_theta = getattr(config, "rope_theta", 10000)
|
rope_theta = getattr(config, "rope_theta", 10000)
|
||||||
rope_scaling = getattr(config, "rope_scaling", None)
|
rope_scaling = getattr(config, "rope_scaling", None)
|
||||||
max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
|
max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
|
||||||
self.self_attn = DeepseekV2Attention(
|
if global_server_args_dict["enable_mla"]:
|
||||||
config=config,
|
self.self_attn = DeepseekV2AttentionMLA(
|
||||||
hidden_size=self.hidden_size,
|
config=config,
|
||||||
num_heads=config.num_attention_heads,
|
hidden_size=self.hidden_size,
|
||||||
qk_nope_head_dim=config.qk_nope_head_dim,
|
num_heads=config.num_attention_heads,
|
||||||
qk_rope_head_dim=config.qk_rope_head_dim,
|
qk_nope_head_dim=config.qk_nope_head_dim,
|
||||||
v_head_dim=config.v_head_dim,
|
qk_rope_head_dim=config.qk_rope_head_dim,
|
||||||
q_lora_rank=config.q_lora_rank if hasattr(config, "q_lora_rank") else None,
|
v_head_dim=config.v_head_dim,
|
||||||
kv_lora_rank=config.kv_lora_rank,
|
q_lora_rank=(
|
||||||
rope_theta=rope_theta,
|
config.q_lora_rank if hasattr(config, "q_lora_rank") else None
|
||||||
rope_scaling=rope_scaling,
|
),
|
||||||
max_position_embeddings=max_position_embeddings,
|
kv_lora_rank=config.kv_lora_rank,
|
||||||
cache_config=cache_config,
|
rope_theta=rope_theta,
|
||||||
quant_config=quant_config,
|
rope_scaling=rope_scaling,
|
||||||
layer_id=layer_id,
|
max_position_embeddings=max_position_embeddings,
|
||||||
)
|
cache_config=cache_config,
|
||||||
|
quant_config=quant_config,
|
||||||
|
layer_id=layer_id,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.self_attn = DeepseekV2Attention(
|
||||||
|
config=config,
|
||||||
|
hidden_size=self.hidden_size,
|
||||||
|
num_heads=config.num_attention_heads,
|
||||||
|
qk_nope_head_dim=config.qk_nope_head_dim,
|
||||||
|
qk_rope_head_dim=config.qk_rope_head_dim,
|
||||||
|
v_head_dim=config.v_head_dim,
|
||||||
|
q_lora_rank=(
|
||||||
|
config.q_lora_rank if hasattr(config, "q_lora_rank") else None
|
||||||
|
),
|
||||||
|
kv_lora_rank=config.kv_lora_rank,
|
||||||
|
rope_theta=rope_theta,
|
||||||
|
rope_scaling=rope_scaling,
|
||||||
|
max_position_embeddings=max_position_embeddings,
|
||||||
|
cache_config=cache_config,
|
||||||
|
quant_config=quant_config,
|
||||||
|
layer_id=layer_id,
|
||||||
|
)
|
||||||
if (
|
if (
|
||||||
config.n_routed_experts is not None
|
config.n_routed_experts is not None
|
||||||
and layer_id >= config.first_k_dense_replace
|
and layer_id >= config.first_k_dense_replace
|
||||||
|
|||||||
@@ -80,6 +80,7 @@ class ServerArgs:
|
|||||||
disable_disk_cache: bool = False
|
disable_disk_cache: bool = False
|
||||||
enable_torch_compile: bool = False
|
enable_torch_compile: bool = False
|
||||||
enable_p2p_check: bool = False
|
enable_p2p_check: bool = False
|
||||||
|
enable_mla: bool = False
|
||||||
attention_reduce_in_fp32: bool = False
|
attention_reduce_in_fp32: bool = False
|
||||||
efficient_weight_load: bool = False
|
efficient_weight_load: bool = False
|
||||||
|
|
||||||
@@ -393,6 +394,11 @@ class ServerArgs:
|
|||||||
action="store_true",
|
action="store_true",
|
||||||
help="Enable P2P check for GPU access, otherwise the p2p access is allowed by default.",
|
help="Enable P2P check for GPU access, otherwise the p2p access is allowed by default.",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--enable-mla",
|
||||||
|
action="store_true",
|
||||||
|
help="Enable Multi-head Latent Attention (MLA) for DeepSeek-V2",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--attention-reduce-in-fp32",
|
"--attention-reduce-in-fp32",
|
||||||
action="store_true",
|
action="store_true",
|
||||||
|
|||||||
Reference in New Issue
Block a user