Support MLA for DeepSeek-V2 with Triton - step 1 (#905)

Author: Ke Bao
Date: 2024-08-05 01:40:33 +08:00
Committed by: GitHub
Parent: f4d9953d9d
Commit: e1eae1fd15
10 changed files with 439 additions and 78 deletions

@@ -47,7 +47,12 @@ from sglang.srt.managers.schedule_batch import (
     InputMetadata,
     global_server_args_dict,
 )
-from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPool
+from sglang.srt.mem_cache.memory_pool import (
+    MHATokenToKVPool,
+    MLATokenToKVPool,
+    ReqToTokenPool,
+)
+from sglang.srt.model_config import AttentionArch
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
     get_available_gpu_memory,
@@ -86,6 +91,7 @@ class ModelRunner:
"disable_flashinfer": server_args.disable_flashinfer,
"disable_flashinfer_sampling": server_args.disable_flashinfer_sampling,
"attention_reduce_in_fp32": server_args.attention_reduce_in_fp32,
"enable_mla": server_args.enable_mla,
}
)
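Note on the plumbing above: copying enable_mla into global_server_args_dict lets model code branch on the flag without threading ServerArgs through every constructor (presumably the DeepSeek-V2 attention module is one consumer). A minimal sketch of such a lookup, assuming a plain dict with the key off by default; the usage site is illustrative, not part of this commit:

# Illustrative only: the dict and key come from the hunk above, the call site is assumed.
from sglang.srt.managers.schedule_batch import global_server_args_dict

if global_server_args_dict.get("enable_mla", False):
    # take the MLA (compressed-KV) attention path
    ...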
@@ -193,15 +199,23 @@ class ModelRunner:
         available_gpu_memory = get_available_gpu_memory(
             self.gpu_id, distributed=self.tp_size > 1
         )
-        head_dim = self.model_config.head_dim
-        head_num = self.model_config.get_num_kv_heads(self.tp_size)
-        cell_size = (
-            head_num
-            * head_dim
-            * self.model_config.num_hidden_layers
-            * 2
-            * torch._utils._element_size(self.dtype)
-        )
+        if (
+            self.model_config.attention_arch == AttentionArch.MLA
+            and self.server_args.enable_mla
+        ):
+            cell_size = (
+                (self.model_config.kv_lora_rank + self.model_config.qk_rope_head_dim)
+                * self.model_config.num_hidden_layers
+                * torch._utils._element_size(self.dtype)
+            )
+        else:
+            cell_size = (
+                self.model_config.get_num_kv_heads(self.tp_size)
+                * self.model_config.head_dim
+                * self.model_config.num_hidden_layers
+                * 2
+                * torch._utils._element_size(self.dtype)
+            )
         rest_memory = available_gpu_memory - total_gpu_memory * (
             1 - self.mem_fraction_static
         )
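To see what the MLA branch changes, here is a standalone back-of-the-envelope version of the two cell_size formulas. This is not code from this commit; the DeepSeek-V2 figures (kv_lora_rank=512, qk_rope_head_dim=64, 60 layers) and the 128-head x 192-dim MHA comparison point are assumptions for illustration:

# Standalone sketch of the per-token KV-cache cost computed above (fp16, 2 bytes/element).
def mla_cell_size(kv_lora_rank, qk_rope_head_dim, layer_num, elem_size=2):
    # MLA caches one compressed KV latent plus the decoupled RoPE key per token,
    # shared across all heads, with no separate V cache.
    return (kv_lora_rank + qk_rope_head_dim) * layer_num * elem_size

def mha_cell_size(head_num, head_dim, layer_num, elem_size=2):
    # MHA caches full per-head K and V vectors per layer (hence the factor of 2).
    return head_num * head_dim * 2 * layer_num * elem_size

mla = mla_cell_size(512, 64, 60)   # 69,120 bytes, ~67.5 KiB per token
mha = mha_cell_size(128, 192, 60)  # 5,898,240 bytes, ~5.6 MiB per token
print(mha / mla)                   # ~85x less KV-cache memory per token

Since the max-token budget is derived from rest_memory divided by cell_size later in this function, the much smaller MLA cell_size translates directly into a proportionally larger KV pool in tokens.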
@@ -241,13 +255,28 @@ class ModelRunner:
             max_num_reqs,
             self.model_config.context_len + 8,
         )
-        self.token_to_kv_pool = TokenToKVPool(
-            self.max_total_num_tokens,
-            dtype=self.dtype,
-            head_num=self.model_config.get_num_kv_heads(self.tp_size),
-            head_dim=self.model_config.head_dim,
-            layer_num=self.model_config.num_hidden_layers,
-        )
+        if (
+            self.model_config.attention_arch == AttentionArch.MLA
+            and self.server_args.enable_mla
+        ):
+            self.token_to_kv_pool = MLATokenToKVPool(
+                self.max_total_num_tokens,
+                dtype=self.dtype,
+                kv_lora_rank=self.model_config.kv_lora_rank,
+                qk_rope_head_dim=self.model_config.qk_rope_head_dim,
+                layer_num=self.model_config.num_hidden_layers,
+            )
+            logger.info("using MLA Triton implementation, flashinfer is disabled")
+            # FIXME: temporarily only Triton MLA is supported
+            self.server_args.disable_flashinfer = True
+        else:
+            self.token_to_kv_pool = MHATokenToKVPool(
+                self.max_total_num_tokens,
+                dtype=self.dtype,
+                head_num=self.model_config.get_num_kv_heads(self.tp_size),
+                head_dim=self.model_config.head_dim,
+                layer_num=self.model_config.num_hidden_layers,
+            )
         logger.info(
             f"[gpu={self.gpu_id}] Memory pool end. "
             f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"