From e9fc2ac7b6115bf3a690ccc4b6f52bb77d5c7b7a Mon Sep 17 00:00:00 2001
From: ybyang <10629930+whybeyoung@users.noreply.github.com>
Date: Mon, 14 Apr 2025 22:56:39 +0800
Subject: [PATCH] [PD Bug] fix MLA get_contiguous_buf_infos error (#5384)

---
 python/sglang/srt/mem_cache/memory_pool.py | 13 ++++---------
 1 file changed, 4 insertions(+), 9 deletions(-)

diff --git a/python/sglang/srt/mem_cache/memory_pool.py b/python/sglang/srt/mem_cache/memory_pool.py
index ae3d10e2f..b35419578 100644
--- a/python/sglang/srt/mem_cache/memory_pool.py
+++ b/python/sglang/srt/mem_cache/memory_pool.py
@@ -444,15 +444,10 @@ class MLATokenToKVPool(KVCache):
 
     # for disagg
     def get_contiguous_buf_infos(self):
-        kv_data_ptrs = [
-            self.get_key_buffer(i).data_ptr() for i in range(self.layer_num)
-        ] + [self.get_value_buffer(i).data_ptr() for i in range(self.layer_num)]
-        kv_data_lens = [
-            self.get_key_buffer(i).nbytes for i in range(self.layer_num)
-        ] + [self.get_value_buffer(i).nbytes for i in range(self.layer_num)]
-        kv_item_lens = [
-            self.get_key_buffer(i)[0].nbytes for i in range(self.layer_num)
-        ] + [self.get_value_buffer(i)[0].nbytes for i in range(self.layer_num)]
+        # MLA has only one kv_buffer, so only the information of this buffer needs to be returned.
+        kv_data_ptrs = [self.kv_buffer[i].data_ptr() for i in range(self.layer_num)]
+        kv_data_lens = [self.kv_buffer[i].nbytes for i in range(self.layer_num)]
+        kv_item_lens = [self.kv_buffer[i][0].nbytes for i in range(self.layer_num)]
         return kv_data_ptrs, kv_data_lens, kv_item_lens
 
     def get_key_buffer(self, layer_id: int):