[PD] Support page size > 1 (#5561)
This commit is contained in:
@@ -35,6 +35,7 @@ from sglang.srt.disaggregation.utils import (
|
||||
ReqToMetadataIdxAllocator,
|
||||
TransferBackend,
|
||||
get_kv_class,
|
||||
kv_to_page_indices,
|
||||
poll_and_all_reduce,
|
||||
)
|
||||
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
|
||||
@@ -205,7 +206,10 @@ class DecodePreallocQueue:
|
||||
self.req_to_metadata_buffer_idx_allocator.alloc()
|
||||
)
|
||||
assert decode_req.metadata_buffer_index is not None
|
||||
decode_req.kv_receiver.init(kv_indices, decode_req.metadata_buffer_index)
|
||||
page_indices = kv_to_page_indices(
|
||||
kv_indices, self.token_to_kv_pool_allocator.page_size
|
||||
)
|
||||
decode_req.kv_receiver.init(page_indices, decode_req.metadata_buffer_index)
|
||||
preallocated_reqs.append(decode_req)
|
||||
indices_to_remove.add(i)
|
||||
|
||||
@@ -245,10 +249,30 @@ class DecodePreallocQueue:
|
||||
assert req_pool_indices is not None
|
||||
|
||||
req.req_pool_idx = req_pool_indices[0]
|
||||
kv_loc = self.token_to_kv_pool_allocator.alloc(
|
||||
len(req.origin_input_ids) + max(len(req.output_ids) - 1, 0)
|
||||
)
|
||||
|
||||
if self.token_to_kv_pool_allocator.page_size == 1:
|
||||
kv_loc = self.token_to_kv_pool_allocator.alloc(
|
||||
len(req.origin_input_ids) + max(len(req.output_ids) - 1, 0)
|
||||
)
|
||||
else:
|
||||
num_tokens = len(req.origin_input_ids) + max(len(req.output_ids) - 1, 0)
|
||||
kv_loc = self.token_to_kv_pool_allocator.alloc_extend(
|
||||
prefix_lens=torch.tensor(
|
||||
[0],
|
||||
dtype=torch.int64,
|
||||
device=self.token_to_kv_pool_allocator.device,
|
||||
),
|
||||
seq_lens=torch.tensor(
|
||||
[num_tokens],
|
||||
dtype=torch.int64,
|
||||
device=self.token_to_kv_pool_allocator.device,
|
||||
),
|
||||
last_loc=torch.tensor(
|
||||
[-1],
|
||||
dtype=torch.int64,
|
||||
device=self.token_to_kv_pool_allocator.device,
|
||||
),
|
||||
extend_num_tokens=num_tokens,
|
||||
)
|
||||
assert kv_loc is not None
|
||||
|
||||
self.req_to_token_pool.write((req.req_pool_idx, slice(0, len(kv_loc))), kv_loc)
|
||||
|
||||
@@ -31,6 +31,8 @@ from sglang.srt.disaggregation.utils import (
|
||||
ReqToMetadataIdxAllocator,
|
||||
TransferBackend,
|
||||
get_kv_class,
|
||||
kv_to_page_indices,
|
||||
kv_to_page_num,
|
||||
poll_and_all_reduce,
|
||||
)
|
||||
from sglang.srt.managers.schedule_batch import FINISH_LENGTH, Req, ScheduleBatch
|
||||
@@ -154,7 +156,8 @@ class PrefillBootstrapQueue:
|
||||
self.req_to_metadata_buffer_idx_allocator.alloc()
|
||||
)
|
||||
assert req.metadata_buffer_index is not None
|
||||
req.disagg_kv_sender.init(num_kv_indices, req.metadata_buffer_index)
|
||||
num_pages = kv_to_page_num(num_kv_indices, self.token_to_kv_pool.page_size)
|
||||
req.disagg_kv_sender.init(num_pages, req.metadata_buffer_index)
|
||||
|
||||
bootstrapped_reqs.append(req)
|
||||
indices_to_remove.add(i)
|
||||
@@ -300,4 +303,7 @@ class SchedulerDisaggregationPrefillMixin:
|
||||
req.metadata_buffer_index, token_id
|
||||
)
|
||||
is_last = token_id is not None
|
||||
req.disagg_kv_sender.send(kv_indices, slice(start_idx, end_idx), is_last)
|
||||
page_indices = kv_to_page_indices(
|
||||
kv_indices, self.token_to_kv_pool_allocator.page_size
|
||||
)
|
||||
req.disagg_kv_sender.send(page_indices, slice(start_idx, end_idx), is_last)
|
||||
|
||||
@@ -4,6 +4,7 @@ from collections import deque
|
||||
from enum import Enum
|
||||
from typing import List
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
@@ -73,3 +74,17 @@ def get_kv_class(transfer_backend: TransferBackend, class_type: KVClassType):
|
||||
}
|
||||
return class_mapping.get(class_type)
|
||||
raise ValueError(f"Unsupported transfer backend: {transfer_backend}")
|
||||
|
||||
|
||||
def kv_to_page_indices(kv_indices: np.ndarray, page_size: int):
|
||||
# 1. The page is guaruanteed to be full except the last page.
|
||||
# 2. page index = kv_index // page_size
|
||||
# The return vector is kv_indices[::page_size] // page_size
|
||||
if page_size == 1: # shortcut
|
||||
return kv_indices
|
||||
return kv_indices[::page_size] // page_size
|
||||
|
||||
|
||||
def kv_to_page_num(num_kv_indices: int, page_size: int):
|
||||
# ceil(num_kv_indices / page_size)
|
||||
return (num_kv_indices + page_size - 1) // page_size
|
||||
|
||||
@@ -286,8 +286,12 @@ class MHATokenToKVPool(KVCache):
|
||||
self.get_key_buffer(i).nbytes for i in range(self.layer_num)
|
||||
] + [self.get_value_buffer(i).nbytes for i in range(self.layer_num)]
|
||||
kv_item_lens = [
|
||||
self.get_key_buffer(i)[0].nbytes for i in range(self.layer_num)
|
||||
] + [self.get_value_buffer(i)[0].nbytes for i in range(self.layer_num)]
|
||||
self.get_key_buffer(i)[0].nbytes * self.page_size
|
||||
for i in range(self.layer_num)
|
||||
] + [
|
||||
self.get_value_buffer(i)[0].nbytes * self.page_size
|
||||
for i in range(self.layer_num)
|
||||
]
|
||||
return kv_data_ptrs, kv_data_lens, kv_item_lens
|
||||
|
||||
# Todo: different memory layout
|
||||
|
||||
Reference in New Issue
Block a user