diff --git a/python/sglang/srt/mem_cache/memory_pool.py b/python/sglang/srt/mem_cache/memory_pool.py index 2e0766222..95ad7b538 100644 --- a/python/sglang/srt/mem_cache/memory_pool.py +++ b/python/sglang/srt/mem_cache/memory_pool.py @@ -520,8 +520,13 @@ class SWAKVPool(KVCache): self.layers_mapping[global_layer_id] = (swa_layer_id, True) self.full_to_swa_index_mapping: Optional[torch.Tensor] = None + k_size, v_size = self.get_kv_size_bytes() + self.mem_usage = (k_size + v_size) / GB + def get_kv_size_bytes(self): - raise NotImplementedError + k_size, v_size = self.full_kv_pool.get_kv_size_bytes() + k_size_swa, v_size_swa = self.swa_kv_pool.get_kv_size_bytes() + return k_size + k_size_swa, v_size + v_size_swa def get_contiguous_buf_infos(self): full_kv_data_ptrs, full_kv_data_lens, full_kv_item_lens = (