fix MLATokenToKVPoolHost get_size_per_token bug (#5161)

Co-authored-by: AniZpZ <zhuangsen.zp@antgroup.com>
This commit is contained in:
huangtingwei
2025-04-14 03:37:26 +08:00
committed by GitHub
parent a9499885e9
commit 5fbafbb8f8

View File

@@ -879,7 +879,12 @@ class MLATokenToKVPoolHost(HostKVCache):
self.qk_rope_head_dim = self.device_pool.qk_rope_head_dim
self.layer_num = self.device_pool.layer_num
return (self.kv_lora_rank + self.qk_rope_head_dim) * 1 * self.dtype.itemsize
return (
(self.kv_lora_rank + self.qk_rope_head_dim)
* 1
* self.dtype.itemsize
* self.layer_num
)
def init_kv_buffer(self):
return torch.empty(