Fix MLATokenToKVPoolHost get_size_per_token bug: include layer_num in the per-token size (#5161)
Co-authored-by: AniZpZ <zhuangsen.zp@antgroup.com>
This commit is contained in:
@@ -879,7 +879,12 @@ class MLATokenToKVPoolHost(HostKVCache):
|
|||||||
self.qk_rope_head_dim = self.device_pool.qk_rope_head_dim
|
self.qk_rope_head_dim = self.device_pool.qk_rope_head_dim
|
||||||
self.layer_num = self.device_pool.layer_num
|
self.layer_num = self.device_pool.layer_num
|
||||||
|
|
||||||
return (self.kv_lora_rank + self.qk_rope_head_dim) * 1 * self.dtype.itemsize
|
return (
|
||||||
|
(self.kv_lora_rank + self.qk_rope_head_dim)
|
||||||
|
* 1
|
||||||
|
* self.dtype.itemsize
|
||||||
|
* self.layer_num
|
||||||
|
)
|
||||||
|
|
||||||
def init_kv_buffer(self):
|
def init_kv_buffer(self):
|
||||||
return torch.empty(
|
return torch.empty(
|
||||||
|
|||||||
Reference in New Issue
Block a user