diff --git a/python/sglang/srt/mem_cache/memory_pool.py b/python/sglang/srt/mem_cache/memory_pool.py
index ae3d10e2f..b35419578 100644
--- a/python/sglang/srt/mem_cache/memory_pool.py
+++ b/python/sglang/srt/mem_cache/memory_pool.py
@@ -444,15 +444,10 @@ class MLATokenToKVPool(KVCache):
 
     # for disagg
     def get_contiguous_buf_infos(self):
-        kv_data_ptrs = [
-            self.get_key_buffer(i).data_ptr() for i in range(self.layer_num)
-        ] + [self.get_value_buffer(i).data_ptr() for i in range(self.layer_num)]
-        kv_data_lens = [
-            self.get_key_buffer(i).nbytes for i in range(self.layer_num)
-        ] + [self.get_value_buffer(i).nbytes for i in range(self.layer_num)]
-        kv_item_lens = [
-            self.get_key_buffer(i)[0].nbytes for i in range(self.layer_num)
-        ] + [self.get_value_buffer(i)[0].nbytes for i in range(self.layer_num)]
+        # MLA has only one kv_buffer, so only the information of this buffer needs to be returned.
+        kv_data_ptrs = [self.kv_buffer[i].data_ptr() for i in range(self.layer_num)]
+        kv_data_lens = [self.kv_buffer[i].nbytes for i in range(self.layer_num)]
+        kv_item_lens = [self.kv_buffer[i][0].nbytes for i in range(self.layer_num)]
         return kv_data_ptrs, kv_data_lens, kv_item_lens
 
     def get_key_buffer(self, layer_id: int):