[minor fix] llama4 hybrid memory (#7950)
This commit is contained in:
@@ -520,8 +520,13 @@ class SWAKVPool(KVCache):
|
|||||||
self.layers_mapping[global_layer_id] = (swa_layer_id, True)
|
self.layers_mapping[global_layer_id] = (swa_layer_id, True)
|
||||||
self.full_to_swa_index_mapping: Optional[torch.Tensor] = None
|
self.full_to_swa_index_mapping: Optional[torch.Tensor] = None
|
||||||
|
|
||||||
|
k_size, v_size = self.get_kv_size_bytes()
|
||||||
|
self.mem_usage = (k_size + v_size) / GB
|
||||||
|
|
||||||
def get_kv_size_bytes(self):
|
def get_kv_size_bytes(self):
|
||||||
raise NotImplementedError
|
k_size, v_size = self.full_kv_pool.get_kv_size_bytes()
|
||||||
|
k_size_swa, v_size_swa = self.swa_kv_pool.get_kv_size_bytes()
|
||||||
|
return k_size + k_size_swa, v_size + v_size_swa
|
||||||
|
|
||||||
def get_contiguous_buf_infos(self):
|
def get_contiguous_buf_infos(self):
|
||||||
full_kv_data_ptrs, full_kv_data_lens, full_kv_item_lens = (
|
full_kv_data_ptrs, full_kv_data_lens, full_kv_item_lens = (
|
||||||
|
|||||||
Reference in New Issue
Block a user