[PD] Support page size > 1 (#5561)

This commit is contained in:
Byron Hsu
2025-04-19 21:54:27 -07:00
committed by GitHub
parent 20f1c8e374
commit ab4b5606e4
4 changed files with 58 additions and 9 deletions

View File

@@ -35,6 +35,7 @@ from sglang.srt.disaggregation.utils import (
ReqToMetadataIdxAllocator, ReqToMetadataIdxAllocator,
TransferBackend, TransferBackend,
get_kv_class, get_kv_class,
kv_to_page_indices,
poll_and_all_reduce, poll_and_all_reduce,
) )
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
@@ -205,7 +206,10 @@ class DecodePreallocQueue:
self.req_to_metadata_buffer_idx_allocator.alloc() self.req_to_metadata_buffer_idx_allocator.alloc()
) )
assert decode_req.metadata_buffer_index is not None assert decode_req.metadata_buffer_index is not None
decode_req.kv_receiver.init(kv_indices, decode_req.metadata_buffer_index) page_indices = kv_to_page_indices(
kv_indices, self.token_to_kv_pool_allocator.page_size
)
decode_req.kv_receiver.init(page_indices, decode_req.metadata_buffer_index)
preallocated_reqs.append(decode_req) preallocated_reqs.append(decode_req)
indices_to_remove.add(i) indices_to_remove.add(i)
@@ -245,10 +249,30 @@ class DecodePreallocQueue:
assert req_pool_indices is not None assert req_pool_indices is not None
req.req_pool_idx = req_pool_indices[0] req.req_pool_idx = req_pool_indices[0]
kv_loc = self.token_to_kv_pool_allocator.alloc( if self.token_to_kv_pool_allocator.page_size == 1:
len(req.origin_input_ids) + max(len(req.output_ids) - 1, 0) kv_loc = self.token_to_kv_pool_allocator.alloc(
) len(req.origin_input_ids) + max(len(req.output_ids) - 1, 0)
)
else:
num_tokens = len(req.origin_input_ids) + max(len(req.output_ids) - 1, 0)
kv_loc = self.token_to_kv_pool_allocator.alloc_extend(
prefix_lens=torch.tensor(
[0],
dtype=torch.int64,
device=self.token_to_kv_pool_allocator.device,
),
seq_lens=torch.tensor(
[num_tokens],
dtype=torch.int64,
device=self.token_to_kv_pool_allocator.device,
),
last_loc=torch.tensor(
[-1],
dtype=torch.int64,
device=self.token_to_kv_pool_allocator.device,
),
extend_num_tokens=num_tokens,
)
assert kv_loc is not None assert kv_loc is not None
self.req_to_token_pool.write((req.req_pool_idx, slice(0, len(kv_loc))), kv_loc) self.req_to_token_pool.write((req.req_pool_idx, slice(0, len(kv_loc))), kv_loc)

View File

@@ -31,6 +31,8 @@ from sglang.srt.disaggregation.utils import (
ReqToMetadataIdxAllocator, ReqToMetadataIdxAllocator,
TransferBackend, TransferBackend,
get_kv_class, get_kv_class,
kv_to_page_indices,
kv_to_page_num,
poll_and_all_reduce, poll_and_all_reduce,
) )
from sglang.srt.managers.schedule_batch import FINISH_LENGTH, Req, ScheduleBatch from sglang.srt.managers.schedule_batch import FINISH_LENGTH, Req, ScheduleBatch
@@ -154,7 +156,8 @@ class PrefillBootstrapQueue:
self.req_to_metadata_buffer_idx_allocator.alloc() self.req_to_metadata_buffer_idx_allocator.alloc()
) )
assert req.metadata_buffer_index is not None assert req.metadata_buffer_index is not None
req.disagg_kv_sender.init(num_kv_indices, req.metadata_buffer_index) num_pages = kv_to_page_num(num_kv_indices, self.token_to_kv_pool.page_size)
req.disagg_kv_sender.init(num_pages, req.metadata_buffer_index)
bootstrapped_reqs.append(req) bootstrapped_reqs.append(req)
indices_to_remove.add(i) indices_to_remove.add(i)
@@ -300,4 +303,7 @@ class SchedulerDisaggregationPrefillMixin:
req.metadata_buffer_index, token_id req.metadata_buffer_index, token_id
) )
is_last = token_id is not None is_last = token_id is not None
req.disagg_kv_sender.send(kv_indices, slice(start_idx, end_idx), is_last) page_indices = kv_to_page_indices(
kv_indices, self.token_to_kv_pool_allocator.page_size
)
req.disagg_kv_sender.send(page_indices, slice(start_idx, end_idx), is_last)

View File

@@ -4,6 +4,7 @@ from collections import deque
from enum import Enum from enum import Enum
from typing import List from typing import List
import numpy as np
import torch import torch
import torch.distributed as dist import torch.distributed as dist
@@ -73,3 +74,17 @@ def get_kv_class(transfer_backend: TransferBackend, class_type: KVClassType):
} }
return class_mapping.get(class_type) return class_mapping.get(class_type)
raise ValueError(f"Unsupported transfer backend: {transfer_backend}") raise ValueError(f"Unsupported transfer backend: {transfer_backend}")
def kv_to_page_indices(kv_indices: np.ndarray, page_size: int) -> np.ndarray:
    """Convert per-token KV cache indices to page indices.

    Assumes every page is full except possibly the last one, so the first
    token of each page appears at every ``page_size``-th position of
    ``kv_indices``; page index = kv_index // page_size.

    Args:
        kv_indices: Token-level KV cache locations, contiguous within a page.
        page_size: Number of tokens per page (>= 1).

    Returns:
        One page index per page spanned by ``kv_indices``. When
        ``page_size == 1`` the input array is returned unchanged.
    """
    if page_size == 1:  # shortcut: tokens and pages coincide
        return kv_indices
    # 1. The page is guaranteed to be full except the last page.
    # 2. page index = kv_index // page_size
    # The return vector is kv_indices[::page_size] // page_size
    return kv_indices[::page_size] // page_size
def kv_to_page_num(num_kv_indices: int, page_size: int):
    """Return the number of pages needed to hold ``num_kv_indices`` tokens.

    Equivalent to ``ceil(num_kv_indices / page_size)`` in pure integer
    arithmetic.
    """
    # Ceiling division via the negation trick: ceil(a / b) == -(-a // b).
    return -(-num_kv_indices // page_size)

View File

@@ -286,8 +286,12 @@ class MHATokenToKVPool(KVCache):
self.get_key_buffer(i).nbytes for i in range(self.layer_num) self.get_key_buffer(i).nbytes for i in range(self.layer_num)
] + [self.get_value_buffer(i).nbytes for i in range(self.layer_num)] ] + [self.get_value_buffer(i).nbytes for i in range(self.layer_num)]
kv_item_lens = [ kv_item_lens = [
self.get_key_buffer(i)[0].nbytes for i in range(self.layer_num) self.get_key_buffer(i)[0].nbytes * self.page_size
] + [self.get_value_buffer(i)[0].nbytes for i in range(self.layer_num)] for i in range(self.layer_num)
] + [
self.get_value_buffer(i)[0].nbytes * self.page_size
for i in range(self.layer_num)
]
return kv_data_ptrs, kv_data_lens, kv_item_lens return kv_data_ptrs, kv_data_lens, kv_item_lens
# Todo: different memory layout # Todo: different memory layout