fix MLATokenToKVPoolHost get_size_per_token bug (#5161)
Co-authored-by: AniZpZ <zhuangsen.zp@antgroup.com>
This commit is contained in:
@@ -879,7 +879,12 @@ class MLATokenToKVPoolHost(HostKVCache):
         self.qk_rope_head_dim = self.device_pool.qk_rope_head_dim
         self.layer_num = self.device_pool.layer_num

-        return (self.kv_lora_rank + self.qk_rope_head_dim) * 1 * self.dtype.itemsize
+        return (
+            (self.kv_lora_rank + self.qk_rope_head_dim)
+            * 1
+            * self.dtype.itemsize
+            * self.layer_num
+        )

     def init_kv_buffer(self):
         return torch.empty(
||||
Reference in New Issue
Block a user