Support MLA for DeepSeek-V2 with Triton - step 1 (#905)

This commit is contained in:
Ke Bao
2024-08-05 01:40:33 +08:00
committed by GitHub
parent f4d9953d9d
commit e1eae1fd15
10 changed files with 439 additions and 78 deletions

View File

@@ -29,7 +29,7 @@ from sglang.global_config import global_config
from sglang.srt.constrained import RegexGuide
from sglang.srt.constrained.jump_forward import JumpForwardMap
from sglang.srt.mem_cache.chunk_cache import ChunkCache
-from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPool
+from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
from sglang.srt.mem_cache.radix_cache import RadixCache
INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
@@ -39,6 +39,7 @@ global_server_args_dict = {
"disable_flashinfer": False,
"disable_flashinfer_sampling": False,
"attention_reduce_in_fp32": False,
+"enable_mla": False,
}
@@ -289,7 +290,7 @@ class Batch:
# Request, memory pool, and cache
reqs: List[Req]
req_to_token_pool: ReqToTokenPool
-token_to_kv_pool: TokenToKVPool
+token_to_kv_pool: BaseTokenToKVPool
tree_cache: RadixCache
# Batched arguments to model runner
@@ -780,7 +781,7 @@ class InputMetadata:
seq_lens: torch.Tensor
positions: torch.Tensor
req_to_token_pool: ReqToTokenPool
-token_to_kv_pool: TokenToKVPool
+token_to_kv_pool: BaseTokenToKVPool
# For extend
extend_seq_lens: torch.Tensor