[PD] Support page size > 1 (#5561)
This commit is contained in:
@@ -35,6 +35,7 @@ from sglang.srt.disaggregation.utils import (
|
|||||||
ReqToMetadataIdxAllocator,
|
ReqToMetadataIdxAllocator,
|
||||||
TransferBackend,
|
TransferBackend,
|
||||||
get_kv_class,
|
get_kv_class,
|
||||||
|
kv_to_page_indices,
|
||||||
poll_and_all_reduce,
|
poll_and_all_reduce,
|
||||||
)
|
)
|
||||||
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
|
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
|
||||||
@@ -205,7 +206,10 @@ class DecodePreallocQueue:
|
|||||||
self.req_to_metadata_buffer_idx_allocator.alloc()
|
self.req_to_metadata_buffer_idx_allocator.alloc()
|
||||||
)
|
)
|
||||||
assert decode_req.metadata_buffer_index is not None
|
assert decode_req.metadata_buffer_index is not None
|
||||||
decode_req.kv_receiver.init(kv_indices, decode_req.metadata_buffer_index)
|
page_indices = kv_to_page_indices(
|
||||||
|
kv_indices, self.token_to_kv_pool_allocator.page_size
|
||||||
|
)
|
||||||
|
decode_req.kv_receiver.init(page_indices, decode_req.metadata_buffer_index)
|
||||||
preallocated_reqs.append(decode_req)
|
preallocated_reqs.append(decode_req)
|
||||||
indices_to_remove.add(i)
|
indices_to_remove.add(i)
|
||||||
|
|
||||||
@@ -245,10 +249,30 @@ class DecodePreallocQueue:
|
|||||||
assert req_pool_indices is not None
|
assert req_pool_indices is not None
|
||||||
|
|
||||||
req.req_pool_idx = req_pool_indices[0]
|
req.req_pool_idx = req_pool_indices[0]
|
||||||
kv_loc = self.token_to_kv_pool_allocator.alloc(
|
if self.token_to_kv_pool_allocator.page_size == 1:
|
||||||
len(req.origin_input_ids) + max(len(req.output_ids) - 1, 0)
|
kv_loc = self.token_to_kv_pool_allocator.alloc(
|
||||||
)
|
len(req.origin_input_ids) + max(len(req.output_ids) - 1, 0)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
num_tokens = len(req.origin_input_ids) + max(len(req.output_ids) - 1, 0)
|
||||||
|
kv_loc = self.token_to_kv_pool_allocator.alloc_extend(
|
||||||
|
prefix_lens=torch.tensor(
|
||||||
|
[0],
|
||||||
|
dtype=torch.int64,
|
||||||
|
device=self.token_to_kv_pool_allocator.device,
|
||||||
|
),
|
||||||
|
seq_lens=torch.tensor(
|
||||||
|
[num_tokens],
|
||||||
|
dtype=torch.int64,
|
||||||
|
device=self.token_to_kv_pool_allocator.device,
|
||||||
|
),
|
||||||
|
last_loc=torch.tensor(
|
||||||
|
[-1],
|
||||||
|
dtype=torch.int64,
|
||||||
|
device=self.token_to_kv_pool_allocator.device,
|
||||||
|
),
|
||||||
|
extend_num_tokens=num_tokens,
|
||||||
|
)
|
||||||
assert kv_loc is not None
|
assert kv_loc is not None
|
||||||
|
|
||||||
self.req_to_token_pool.write((req.req_pool_idx, slice(0, len(kv_loc))), kv_loc)
|
self.req_to_token_pool.write((req.req_pool_idx, slice(0, len(kv_loc))), kv_loc)
|
||||||
|
|||||||
@@ -31,6 +31,8 @@ from sglang.srt.disaggregation.utils import (
|
|||||||
ReqToMetadataIdxAllocator,
|
ReqToMetadataIdxAllocator,
|
||||||
TransferBackend,
|
TransferBackend,
|
||||||
get_kv_class,
|
get_kv_class,
|
||||||
|
kv_to_page_indices,
|
||||||
|
kv_to_page_num,
|
||||||
poll_and_all_reduce,
|
poll_and_all_reduce,
|
||||||
)
|
)
|
||||||
from sglang.srt.managers.schedule_batch import FINISH_LENGTH, Req, ScheduleBatch
|
from sglang.srt.managers.schedule_batch import FINISH_LENGTH, Req, ScheduleBatch
|
||||||
@@ -154,7 +156,8 @@ class PrefillBootstrapQueue:
|
|||||||
self.req_to_metadata_buffer_idx_allocator.alloc()
|
self.req_to_metadata_buffer_idx_allocator.alloc()
|
||||||
)
|
)
|
||||||
assert req.metadata_buffer_index is not None
|
assert req.metadata_buffer_index is not None
|
||||||
req.disagg_kv_sender.init(num_kv_indices, req.metadata_buffer_index)
|
num_pages = kv_to_page_num(num_kv_indices, self.token_to_kv_pool.page_size)
|
||||||
|
req.disagg_kv_sender.init(num_pages, req.metadata_buffer_index)
|
||||||
|
|
||||||
bootstrapped_reqs.append(req)
|
bootstrapped_reqs.append(req)
|
||||||
indices_to_remove.add(i)
|
indices_to_remove.add(i)
|
||||||
@@ -300,4 +303,7 @@ class SchedulerDisaggregationPrefillMixin:
|
|||||||
req.metadata_buffer_index, token_id
|
req.metadata_buffer_index, token_id
|
||||||
)
|
)
|
||||||
is_last = token_id is not None
|
is_last = token_id is not None
|
||||||
req.disagg_kv_sender.send(kv_indices, slice(start_idx, end_idx), is_last)
|
page_indices = kv_to_page_indices(
|
||||||
|
kv_indices, self.token_to_kv_pool_allocator.page_size
|
||||||
|
)
|
||||||
|
req.disagg_kv_sender.send(page_indices, slice(start_idx, end_idx), is_last)
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ from collections import deque
|
|||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
|
|
||||||
@@ -73,3 +74,17 @@ def get_kv_class(transfer_backend: TransferBackend, class_type: KVClassType):
|
|||||||
}
|
}
|
||||||
return class_mapping.get(class_type)
|
return class_mapping.get(class_type)
|
||||||
raise ValueError(f"Unsupported transfer backend: {transfer_backend}")
|
raise ValueError(f"Unsupported transfer backend: {transfer_backend}")
|
||||||
|
|
||||||
|
|
||||||
|
def kv_to_page_indices(kv_indices: np.ndarray, page_size: int):
|
||||||
|
# 1. The page is guaruanteed to be full except the last page.
|
||||||
|
# 2. page index = kv_index // page_size
|
||||||
|
# The return vector is kv_indices[::page_size] // page_size
|
||||||
|
if page_size == 1: # shortcut
|
||||||
|
return kv_indices
|
||||||
|
return kv_indices[::page_size] // page_size
|
||||||
|
|
||||||
|
|
||||||
|
def kv_to_page_num(num_kv_indices: int, page_size: int):
|
||||||
|
# ceil(num_kv_indices / page_size)
|
||||||
|
return (num_kv_indices + page_size - 1) // page_size
|
||||||
|
|||||||
@@ -286,8 +286,12 @@ class MHATokenToKVPool(KVCache):
|
|||||||
self.get_key_buffer(i).nbytes for i in range(self.layer_num)
|
self.get_key_buffer(i).nbytes for i in range(self.layer_num)
|
||||||
] + [self.get_value_buffer(i).nbytes for i in range(self.layer_num)]
|
] + [self.get_value_buffer(i).nbytes for i in range(self.layer_num)]
|
||||||
kv_item_lens = [
|
kv_item_lens = [
|
||||||
self.get_key_buffer(i)[0].nbytes for i in range(self.layer_num)
|
self.get_key_buffer(i)[0].nbytes * self.page_size
|
||||||
] + [self.get_value_buffer(i)[0].nbytes for i in range(self.layer_num)]
|
for i in range(self.layer_num)
|
||||||
|
] + [
|
||||||
|
self.get_value_buffer(i)[0].nbytes * self.page_size
|
||||||
|
for i in range(self.layer_num)
|
||||||
|
]
|
||||||
return kv_data_ptrs, kv_data_lens, kv_item_lens
|
return kv_data_ptrs, kv_data_lens, kv_item_lens
|
||||||
|
|
||||||
# Todo: different memory layout
|
# Todo: different memory layout
|
||||||
|
|||||||
Reference in New Issue
Block a user