From 5fbafbb8f8bce429f1c06a53bfd45bcd0eb26cc7 Mon Sep 17 00:00:00 2001 From: huangtingwei <141888744+huangtingwei9988@users.noreply.github.com> Date: Mon, 14 Apr 2025 03:37:26 +0800 Subject: [PATCH] fix MLATokenToKVPoolHost get_size_per_token bug (#5161) Co-authored-by: AniZpZ --- python/sglang/srt/mem_cache/memory_pool.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/mem_cache/memory_pool.py b/python/sglang/srt/mem_cache/memory_pool.py index 84198c011..ae3d10e2f 100644 --- a/python/sglang/srt/mem_cache/memory_pool.py +++ b/python/sglang/srt/mem_cache/memory_pool.py @@ -879,7 +879,12 @@ class MLATokenToKVPoolHost(HostKVCache): self.qk_rope_head_dim = self.device_pool.qk_rope_head_dim self.layer_num = self.device_pool.layer_num - return (self.kv_lora_rank + self.qk_rope_head_dim) * 1 * self.dtype.itemsize + return ( + (self.kv_lora_rank + self.qk_rope_head_dim) + * 1 + * self.dtype.itemsize + * self.layer_num + ) def init_kv_buffer(self): return torch.empty(