diff --git a/python/sglang/srt/disaggregation/decode.py b/python/sglang/srt/disaggregation/decode.py index 23acf5222..db2ed2ae9 100644 --- a/python/sglang/srt/disaggregation/decode.py +++ b/python/sglang/srt/disaggregation/decode.py @@ -35,6 +35,7 @@ from sglang.srt.disaggregation.utils import ( ReqToMetadataIdxAllocator, TransferBackend, get_kv_class, + kv_to_page_indices, poll_and_all_reduce, ) from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache @@ -205,7 +206,10 @@ class DecodePreallocQueue: self.req_to_metadata_buffer_idx_allocator.alloc() ) assert decode_req.metadata_buffer_index is not None - decode_req.kv_receiver.init(kv_indices, decode_req.metadata_buffer_index) + page_indices = kv_to_page_indices( + kv_indices, self.token_to_kv_pool_allocator.page_size + ) + decode_req.kv_receiver.init(page_indices, decode_req.metadata_buffer_index) preallocated_reqs.append(decode_req) indices_to_remove.add(i) @@ -245,10 +249,30 @@ class DecodePreallocQueue: assert req_pool_indices is not None req.req_pool_idx = req_pool_indices[0] - kv_loc = self.token_to_kv_pool_allocator.alloc( - len(req.origin_input_ids) + max(len(req.output_ids) - 1, 0) - ) - + if self.token_to_kv_pool_allocator.page_size == 1: + kv_loc = self.token_to_kv_pool_allocator.alloc( + len(req.origin_input_ids) + max(len(req.output_ids) - 1, 0) + ) + else: + num_tokens = len(req.origin_input_ids) + max(len(req.output_ids) - 1, 0) + kv_loc = self.token_to_kv_pool_allocator.alloc_extend( + prefix_lens=torch.tensor( + [0], + dtype=torch.int64, + device=self.token_to_kv_pool_allocator.device, + ), + seq_lens=torch.tensor( + [num_tokens], + dtype=torch.int64, + device=self.token_to_kv_pool_allocator.device, + ), + last_loc=torch.tensor( + [-1], + dtype=torch.int64, + device=self.token_to_kv_pool_allocator.device, + ), + extend_num_tokens=num_tokens, + ) assert kv_loc is not None self.req_to_token_pool.write((req.req_pool_idx, slice(0, len(kv_loc))), kv_loc) diff --git 
a/python/sglang/srt/disaggregation/prefill.py b/python/sglang/srt/disaggregation/prefill.py index 692d014bb..9e080926a 100644 --- a/python/sglang/srt/disaggregation/prefill.py +++ b/python/sglang/srt/disaggregation/prefill.py @@ -31,6 +31,8 @@ from sglang.srt.disaggregation.utils import ( ReqToMetadataIdxAllocator, TransferBackend, get_kv_class, + kv_to_page_indices, + kv_to_page_num, poll_and_all_reduce, ) from sglang.srt.managers.schedule_batch import FINISH_LENGTH, Req, ScheduleBatch @@ -154,7 +156,8 @@ class PrefillBootstrapQueue: self.req_to_metadata_buffer_idx_allocator.alloc() ) assert req.metadata_buffer_index is not None - req.disagg_kv_sender.init(num_kv_indices, req.metadata_buffer_index) + num_pages = kv_to_page_num(num_kv_indices, self.token_to_kv_pool.page_size) + req.disagg_kv_sender.init(num_pages, req.metadata_buffer_index) bootstrapped_reqs.append(req) indices_to_remove.add(i) @@ -300,4 +303,7 @@ class SchedulerDisaggregationPrefillMixin: req.metadata_buffer_index, token_id ) is_last = token_id is not None - req.disagg_kv_sender.send(kv_indices, slice(start_idx, end_idx), is_last) + page_indices = kv_to_page_indices( + kv_indices, self.token_to_kv_pool_allocator.page_size + ) + req.disagg_kv_sender.send(page_indices, slice(start_idx, end_idx), is_last) diff --git a/python/sglang/srt/disaggregation/utils.py b/python/sglang/srt/disaggregation/utils.py index 54d344416..7b43566f1 100644 --- a/python/sglang/srt/disaggregation/utils.py +++ b/python/sglang/srt/disaggregation/utils.py @@ -4,6 +4,7 @@ from collections import deque from enum import Enum from typing import List +import numpy as np import torch import torch.distributed as dist @@ -73,3 +74,17 @@ def get_kv_class(transfer_backend: TransferBackend, class_type: KVClassType): } return class_mapping.get(class_type) raise ValueError(f"Unsupported transfer backend: {transfer_backend}") + + +def kv_to_page_indices(kv_indices: np.ndarray, page_size: int): + # 1. 
The page is guaranteed to be full except the last page.