diff --git a/python/sglang/srt/mem_cache/memory_pool.py b/python/sglang/srt/mem_cache/memory_pool.py index 84198c011..ae3d10e2f 100644 --- a/python/sglang/srt/mem_cache/memory_pool.py +++ b/python/sglang/srt/mem_cache/memory_pool.py @@ -879,7 +879,12 @@ class MLATokenToKVPoolHost(HostKVCache): self.qk_rope_head_dim = self.device_pool.qk_rope_head_dim self.layer_num = self.device_pool.layer_num - return (self.kv_lora_rank + self.qk_rope_head_dim) * 1 * self.dtype.itemsize + return ( + (self.kv_lora_rank + self.qk_rope_head_dim) + * 1 + * self.dtype.itemsize + * self.layer_num + ) def init_kv_buffer(self): return torch.empty(