diff --git a/python/sglang/srt/mem_cache/memory_pool.py b/python/sglang/srt/mem_cache/memory_pool.py index 31866e010..84198c011 100644 --- a/python/sglang/srt/mem_cache/memory_pool.py +++ b/python/sglang/srt/mem_cache/memory_pool.py @@ -442,6 +442,19 @@ class MLATokenToKVPool(KVCache): self.layer_transfer_counter = None + # for disagg + def get_contiguous_buf_infos(self): + kv_data_ptrs = [ + self.get_key_buffer(i).data_ptr() for i in range(self.layer_num) + ] + [self.get_value_buffer(i).data_ptr() for i in range(self.layer_num)] + kv_data_lens = [ + self.get_key_buffer(i).nbytes for i in range(self.layer_num) + ] + [self.get_value_buffer(i).nbytes for i in range(self.layer_num)] + kv_item_lens = [ + self.get_key_buffer(i)[0].nbytes for i in range(self.layer_num) + ] + [self.get_value_buffer(i)[0].nbytes for i in range(self.layer_num)] + return kv_data_ptrs, kv_data_lens, kv_item_lens + def get_key_buffer(self, layer_id: int): if self.layer_transfer_counter is not None: self.layer_transfer_counter.wait_until(layer_id)