Support MLA for DeepSeek-V2 with Triton - step 1 (#905)
@@ -47,7 +47,12 @@ from sglang.srt.managers.schedule_batch import (
     InputMetadata,
     global_server_args_dict,
 )
-from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPool
+from sglang.srt.mem_cache.memory_pool import (
+    MHATokenToKVPool,
+    MLATokenToKVPool,
+    ReqToTokenPool,
+)
+from sglang.srt.model_config import AttentionArch
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
     get_available_gpu_memory,
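This hunk splits the monolithic TokenToKVPool import into an MHA pool and an MLA pool. A minimal sketch of the storage-layout difference between the two, with illustrative buffer shapes (the real class definitions live in sglang/srt/mem_cache/memory_pool.py, are not shown in this diff, and allocate on the GPU):

    import torch

    class MHATokenToKVPool:
        # Classic multi-head attention cache: separate K and V buffers,
        # one (token, head, head_dim) tensor per layer.
        def __init__(self, size, dtype, head_num, head_dim, layer_num):
            self.k_buffer = [
                torch.empty((size, head_num, head_dim), dtype=dtype)
                for _ in range(layer_num)
            ]
            self.v_buffer = [
                torch.empty((size, head_num, head_dim), dtype=dtype)
                for _ in range(layer_num)
            ]

    class MLATokenToKVPool:
        # Multi-head latent attention cache: one compressed latent per
        # token per layer -- no head dimension, no separate V buffer.
        def __init__(self, size, dtype, kv_lora_rank, qk_rope_head_dim, layer_num):
            self.kv_buffer = [
                torch.empty((size, kv_lora_rank + qk_rope_head_dim), dtype=dtype)
                for _ in range(layer_num)
            ]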
@@ -86,6 +91,7 @@ class ModelRunner:
                 "disable_flashinfer": server_args.disable_flashinfer,
                 "disable_flashinfer_sampling": server_args.disable_flashinfer_sampling,
                 "attention_reduce_in_fp32": server_args.attention_reduce_in_fp32,
+                "enable_mla": server_args.enable_mla,
             }
         )

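Registering the flag in global_server_args_dict lets model code that never sees a ServerArgs object branch on it. A hedged sketch of such a consumer (the dict and its import are real per the first hunk; the helper function itself is hypothetical):

    from sglang.srt.managers.schedule_batch import global_server_args_dict

    def use_mla_kernels() -> bool:
        # Falls back to False if the server never populated the entry.
        return global_server_args_dict.get("enable_mla", False)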
@@ -193,15 +199,23 @@ class ModelRunner:
         available_gpu_memory = get_available_gpu_memory(
             self.gpu_id, distributed=self.tp_size > 1
         )
-        head_dim = self.model_config.head_dim
-        head_num = self.model_config.get_num_kv_heads(self.tp_size)
-        cell_size = (
-            head_num
-            * head_dim
-            * self.model_config.num_hidden_layers
-            * 2
-            * torch._utils._element_size(self.dtype)
-        )
+        if (
+            self.model_config.attention_arch == AttentionArch.MLA
+            and self.server_args.enable_mla
+        ):
+            cell_size = (
+                (self.model_config.kv_lora_rank + self.model_config.qk_rope_head_dim)
+                * self.model_config.num_hidden_layers
+                * torch._utils._element_size(self.dtype)
+            )
+        else:
+            cell_size = (
+                self.model_config.get_num_kv_heads(self.tp_size)
+                * self.model_config.head_dim
+                * self.model_config.num_hidden_layers
+                * 2
+                * torch._utils._element_size(self.dtype)
+            )
         rest_memory = available_gpu_memory - total_gpu_memory * (
             1 - self.mem_fraction_static
         )
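Under MLA the per-token cache cell drops both the head count and the K/V duplication: each token stores a single (kv_lora_rank + qk_rope_head_dim) latent per layer. A quick sanity check of the arithmetic, assuming DeepSeek-V2's published config values (kv_lora_rank=512, qk_rope_head_dim=64, num_hidden_layers=60) and 2-byte fp16/bf16 elements; these numbers come from the model's config.json, not from this diff:

    kv_lora_rank = 512
    qk_rope_head_dim = 64
    num_hidden_layers = 60
    element_size = 2  # bytes per fp16/bf16 element

    # One compressed latent per token per layer: no head_num factor and no
    # separate K and V copies (hence no "* 2" as in the MHA branch).
    mla_cell = (kv_lora_rank + qk_rope_head_dim) * num_hidden_layers * element_size
    print(mla_cell)  # 69120 bytes, i.e. 67.5 KiB per cached token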
@@ -241,13 +255,28 @@ class ModelRunner:
             max_num_reqs,
             self.model_config.context_len + 8,
         )
-        self.token_to_kv_pool = TokenToKVPool(
-            self.max_total_num_tokens,
-            dtype=self.dtype,
-            head_num=self.model_config.get_num_kv_heads(self.tp_size),
-            head_dim=self.model_config.head_dim,
-            layer_num=self.model_config.num_hidden_layers,
-        )
+        if (
+            self.model_config.attention_arch == AttentionArch.MLA
+            and self.server_args.enable_mla
+        ):
+            self.token_to_kv_pool = MLATokenToKVPool(
+                self.max_total_num_tokens,
+                dtype=self.dtype,
+                kv_lora_rank=self.model_config.kv_lora_rank,
+                qk_rope_head_dim=self.model_config.qk_rope_head_dim,
+                layer_num=self.model_config.num_hidden_layers,
+            )
+            logger.info("using MLA Triton implementation, flashinfer is disabled")
+            # FIXME: temporarily only Triton MLA is supported
+            self.server_args.disable_flashinfer = True
+        else:
+            self.token_to_kv_pool = MHATokenToKVPool(
+                self.max_total_num_tokens,
+                dtype=self.dtype,
+                head_num=self.model_config.get_num_kv_heads(self.tp_size),
+                head_dim=self.model_config.head_dim,
+                layer_num=self.model_config.num_hidden_layers,
+            )
         logger.info(
             f"[gpu={self.gpu_id}] Memory pool end. "
             f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
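The cell_size from the earlier hunk and the pool built here connect directly: the runner divides the memory reserved for the cache by cell_size to size max_total_num_tokens. An illustrative back-of-envelope version of that step, using a made-up memory budget rather than a measured value:

    rest_memory_gib = 40  # hypothetical free GiB reserved for the KV cache
    mla_cell = 69120      # bytes per token, from the sketch above
    max_total_num_tokens = rest_memory_gib * (1 << 30) // mla_cell
    print(max_total_num_tokens)  # ~621,000 cacheable tokens under these assumptions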