[PD] Support decode retract and update decode.py (#7196)
This commit is contained in:
@@ -31,7 +31,7 @@ import numpy as np
|
|||||||
import torch
|
import torch
|
||||||
from torch.distributed import ProcessGroup
|
from torch.distributed import ProcessGroup
|
||||||
|
|
||||||
from sglang.srt.disaggregation.base import BaseKVManager, BaseKVReceiver, KVArgs, KVPoll
|
from sglang.srt.disaggregation.base import BaseKVManager, BaseKVReceiver, KVPoll
|
||||||
from sglang.srt.disaggregation.utils import (
|
from sglang.srt.disaggregation.utils import (
|
||||||
FAKE_BOOTSTRAP_HOST,
|
FAKE_BOOTSTRAP_HOST,
|
||||||
DisaggregationMode,
|
DisaggregationMode,
|
||||||
@@ -45,9 +45,17 @@ from sglang.srt.disaggregation.utils import (
|
|||||||
poll_and_all_reduce,
|
poll_and_all_reduce,
|
||||||
prepare_abort,
|
prepare_abort,
|
||||||
)
|
)
|
||||||
from sglang.srt.managers.schedule_batch import FINISH_ABORT, ScheduleBatch
|
from sglang.srt.managers.schedule_batch import (
|
||||||
|
FINISH_ABORT,
|
||||||
|
ScheduleBatch,
|
||||||
|
global_server_args_dict,
|
||||||
|
)
|
||||||
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
|
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
|
||||||
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
|
from sglang.srt.mem_cache.memory_pool import (
|
||||||
|
KVCache,
|
||||||
|
ReqToTokenPool,
|
||||||
|
TokenToKVPoolAllocator,
|
||||||
|
)
|
||||||
from sglang.srt.model_executor.forward_batch_info import ForwardMode
|
from sglang.srt.model_executor.forward_batch_info import ForwardMode
|
||||||
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
|
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
|
||||||
|
|
||||||
@@ -145,7 +153,11 @@ class DecodePreallocQueue:
|
|||||||
gloo_group: ProcessGroup,
|
gloo_group: ProcessGroup,
|
||||||
tp_rank: int,
|
tp_rank: int,
|
||||||
tp_size: int,
|
tp_size: int,
|
||||||
|
dp_size: int,
|
||||||
|
gpu_id: int,
|
||||||
bootstrap_port: int,
|
bootstrap_port: int,
|
||||||
|
max_total_num_tokens: int,
|
||||||
|
prefill_pp_size: int,
|
||||||
transfer_backend: TransferBackend,
|
transfer_backend: TransferBackend,
|
||||||
):
|
):
|
||||||
self.req_to_token_pool = req_to_token_pool
|
self.req_to_token_pool = req_to_token_pool
|
||||||
@@ -161,25 +173,35 @@ class DecodePreallocQueue:
|
|||||||
self.gloo_group = gloo_group
|
self.gloo_group = gloo_group
|
||||||
self.tp_rank = tp_rank
|
self.tp_rank = tp_rank
|
||||||
self.tp_size = tp_size
|
self.tp_size = tp_size
|
||||||
|
self.dp_size = dp_size
|
||||||
|
self.gpu_id = gpu_id
|
||||||
self.bootstrap_port = bootstrap_port
|
self.bootstrap_port = bootstrap_port
|
||||||
|
self.max_total_num_tokens = max_total_num_tokens
|
||||||
|
self.prefill_pp_size = prefill_pp_size
|
||||||
self.num_reserved_decode_tokens = int(
|
self.num_reserved_decode_tokens = int(
|
||||||
os.environ.get("SGLANG_NUM_RESERVED_DECODE_TOKENS", "512")
|
os.environ.get("SGLANG_NUM_RESERVED_DECODE_TOKENS", "512")
|
||||||
)
|
)
|
||||||
|
self.transfer_backend = transfer_backend
|
||||||
# Queue for requests pending pre-allocation
|
# Queue for requests pending pre-allocation
|
||||||
self.queue: List[DecodeRequest] = []
|
self.queue: List[DecodeRequest] = []
|
||||||
self.transfer_backend = transfer_backend
|
self.retracted_queue: List[Req] = []
|
||||||
|
self.prefill_pp_size = prefill_pp_size
|
||||||
self.kv_manager = self._init_kv_manager()
|
self.kv_manager = self._init_kv_manager()
|
||||||
|
|
||||||
def _init_kv_manager(self) -> BaseKVManager:
|
def _init_kv_manager(self) -> BaseKVManager:
|
||||||
kv_args = KVArgs()
|
kv_args_class = get_kv_class(self.transfer_backend, KVClassType.KVARGS)
|
||||||
kv_args.engine_rank = self.tp_rank
|
kv_args = kv_args_class()
|
||||||
|
|
||||||
|
attn_tp_size = self.tp_size // self.dp_size
|
||||||
|
kv_args.engine_rank = self.tp_rank % (attn_tp_size)
|
||||||
|
kv_args.decode_tp_size = attn_tp_size
|
||||||
|
kv_args.prefill_pp_size = self.prefill_pp_size
|
||||||
kv_data_ptrs, kv_data_lens, kv_item_lens = (
|
kv_data_ptrs, kv_data_lens, kv_item_lens = (
|
||||||
self.token_to_kv_pool.get_contiguous_buf_infos()
|
self.token_to_kv_pool.get_contiguous_buf_infos()
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.draft_token_to_kv_pool is not None:
|
if self.draft_token_to_kv_pool is not None:
|
||||||
|
# We should also transfer draft model kv cache. The indices are
|
||||||
|
# always shared with a target model.
|
||||||
draft_kv_data_ptrs, draft_kv_data_lens, draft_kv_item_lens = (
|
draft_kv_data_ptrs, draft_kv_data_lens, draft_kv_item_lens = (
|
||||||
self.draft_token_to_kv_pool.get_contiguous_buf_infos()
|
self.draft_token_to_kv_pool.get_contiguous_buf_infos()
|
||||||
)
|
)
|
||||||
@@ -194,6 +216,7 @@ class DecodePreallocQueue:
|
|||||||
kv_args.aux_data_ptrs, kv_args.aux_data_lens, kv_args.aux_item_lens = (
|
kv_args.aux_data_ptrs, kv_args.aux_data_lens, kv_args.aux_item_lens = (
|
||||||
self.metadata_buffers.get_buf_infos()
|
self.metadata_buffers.get_buf_infos()
|
||||||
)
|
)
|
||||||
|
|
||||||
kv_args.ib_device = self.scheduler.server_args.disaggregation_ib_device
|
kv_args.ib_device = self.scheduler.server_args.disaggregation_ib_device
|
||||||
kv_args.gpu_id = self.scheduler.gpu_id
|
kv_args.gpu_id = self.scheduler.gpu_id
|
||||||
kv_manager_class = get_kv_class(self.transfer_backend, KVClassType.MANAGER)
|
kv_manager_class = get_kv_class(self.transfer_backend, KVClassType.MANAGER)
|
||||||
@@ -205,27 +228,83 @@ class DecodePreallocQueue:
|
|||||||
)
|
)
|
||||||
return kv_manager
|
return kv_manager
|
||||||
|
|
||||||
def add(self, req: Req) -> None:
|
def add(self, req: Req, is_retracted: bool = False) -> None:
|
||||||
"""Add a request to the pending queue."""
|
"""Add a request to the pending queue."""
|
||||||
if req.bootstrap_host == FAKE_BOOTSTRAP_HOST:
|
if self._check_if_req_exceed_kv_capacity(req):
|
||||||
# Fake transfer for warmup reqs
|
return
|
||||||
kv_receiver_class = get_kv_class(TransferBackend.FAKE, KVClassType.RECEIVER)
|
|
||||||
else:
|
|
||||||
kv_receiver_class = get_kv_class(
|
|
||||||
self.transfer_backend, KVClassType.RECEIVER
|
|
||||||
)
|
|
||||||
kv_receiver = kv_receiver_class(
|
|
||||||
mgr=self.kv_manager,
|
|
||||||
bootstrap_addr=f"{req.bootstrap_host}:{req.bootstrap_port}",
|
|
||||||
bootstrap_room=req.bootstrap_room,
|
|
||||||
data_parallel_rank=req.data_parallel_rank,
|
|
||||||
)
|
|
||||||
self.queue.append(DecodeRequest(req=req, kv_receiver=kv_receiver))
|
|
||||||
|
|
||||||
def extend(self, reqs: List[Req]) -> None:
|
if is_retracted:
|
||||||
|
self.retracted_queue.append(req)
|
||||||
|
else:
|
||||||
|
if req.bootstrap_host == FAKE_BOOTSTRAP_HOST:
|
||||||
|
kv_receiver_class = get_kv_class(
|
||||||
|
TransferBackend.FAKE, KVClassType.RECEIVER
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
kv_receiver_class = get_kv_class(
|
||||||
|
self.transfer_backend, KVClassType.RECEIVER
|
||||||
|
)
|
||||||
|
|
||||||
|
kv_receiver = kv_receiver_class(
|
||||||
|
mgr=self.kv_manager,
|
||||||
|
bootstrap_addr=f"{req.bootstrap_host}:{req.bootstrap_port}",
|
||||||
|
bootstrap_room=req.bootstrap_room,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.queue.append(
|
||||||
|
DecodeRequest(req=req, kv_receiver=kv_receiver, waiting_for_input=False)
|
||||||
|
)
|
||||||
|
|
||||||
|
def _check_if_req_exceed_kv_capacity(self, req: Req) -> bool:
|
||||||
|
if len(req.origin_input_ids) > self.max_total_num_tokens:
|
||||||
|
message = f"Request {req.rid} exceeds the maximum number of tokens: {len(req.origin_input_ids)} > {self.max_total_num_tokens}"
|
||||||
|
logger.error(message)
|
||||||
|
prepare_abort(req, message)
|
||||||
|
self.scheduler.stream_output([req], req.return_logprob)
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
def extend(self, reqs: List[Req], is_retracted: bool = False) -> None:
|
||||||
"""Add a request to the pending queue."""
|
"""Add a request to the pending queue."""
|
||||||
for req in reqs:
|
for req in reqs:
|
||||||
self.add(req)
|
self.add(req, is_retracted=is_retracted)
|
||||||
|
|
||||||
|
def resume_retracted_reqs(self) -> List[Req]:
|
||||||
|
# TODO refactor the scheduling part, reuse with the unified engine logic as much as possible
|
||||||
|
|
||||||
|
# allocate memory
|
||||||
|
resumed_reqs = []
|
||||||
|
indices_to_remove = set()
|
||||||
|
allocatable_tokens = self._allocatable_tokens(count_retracted=False)
|
||||||
|
|
||||||
|
for i, req in enumerate(self.retracted_queue):
|
||||||
|
if self.req_to_token_pool.available_size() <= 0:
|
||||||
|
break
|
||||||
|
|
||||||
|
required_tokens_for_request = (
|
||||||
|
len(req.origin_input_ids)
|
||||||
|
+ len(req.output_ids)
|
||||||
|
+ self.num_reserved_decode_tokens
|
||||||
|
)
|
||||||
|
if required_tokens_for_request > allocatable_tokens:
|
||||||
|
break
|
||||||
|
|
||||||
|
resumed_reqs.append(req)
|
||||||
|
indices_to_remove.add(i)
|
||||||
|
req.is_retracted = False
|
||||||
|
self._pre_alloc(req)
|
||||||
|
allocatable_tokens -= required_tokens_for_request
|
||||||
|
|
||||||
|
# load from cpu, release the cpu copy
|
||||||
|
req.load_kv_cache(self.req_to_token_pool, self.token_to_kv_pool_allocator)
|
||||||
|
|
||||||
|
self.retracted_queue = [
|
||||||
|
entry
|
||||||
|
for i, entry in enumerate(self.retracted_queue)
|
||||||
|
if i not in indices_to_remove
|
||||||
|
]
|
||||||
|
|
||||||
|
return resumed_reqs
|
||||||
|
|
||||||
def _update_handshake_waiters(self) -> None:
|
def _update_handshake_waiters(self) -> None:
|
||||||
if not self.queue:
|
if not self.queue:
|
||||||
@@ -255,6 +334,8 @@ class DecodePreallocQueue:
|
|||||||
error_message,
|
error_message,
|
||||||
status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
|
status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unexpected poll case: {poll}")
|
||||||
|
|
||||||
def pop_preallocated(self) -> List[DecodeRequest]:
|
def pop_preallocated(self) -> List[DecodeRequest]:
|
||||||
"""Pop the preallocated requests from the pending queue (FIFO)."""
|
"""Pop the preallocated requests from the pending queue (FIFO)."""
|
||||||
@@ -262,8 +343,16 @@ class DecodePreallocQueue:
|
|||||||
|
|
||||||
preallocated_reqs = []
|
preallocated_reqs = []
|
||||||
indices_to_remove = set()
|
indices_to_remove = set()
|
||||||
allocatable_tokens = self._allocatable_tokens()
|
|
||||||
|
|
||||||
|
# We need to make sure that the sum of inflight tokens and allocatable tokens is greater than maximum input+output length of each inflight request
|
||||||
|
# Otherwise it is possible for one request running decode out of memory, while all other requests are in the transfer queue that cannot be retracted.
|
||||||
|
retractable_tokens = sum(
|
||||||
|
len(r.origin_input_ids) + len(r.output_ids)
|
||||||
|
for r in self.scheduler.running_batch.reqs
|
||||||
|
)
|
||||||
|
allocatable_tokens = self._allocatable_tokens(
|
||||||
|
retractable_tokens=retractable_tokens, count_retracted=True
|
||||||
|
)
|
||||||
# First, remove all failed requests from the queue
|
# First, remove all failed requests from the queue
|
||||||
for i, decode_req in enumerate(self.queue):
|
for i, decode_req in enumerate(self.queue):
|
||||||
if isinstance(decode_req.req.finished_reason, FINISH_ABORT):
|
if isinstance(decode_req.req.finished_reason, FINISH_ABORT):
|
||||||
@@ -272,6 +361,7 @@ class DecodePreallocQueue:
|
|||||||
)
|
)
|
||||||
indices_to_remove.add(i)
|
indices_to_remove.add(i)
|
||||||
|
|
||||||
|
# Then, preallocate the remaining requests if possible
|
||||||
for i, decode_req in enumerate(self.queue):
|
for i, decode_req in enumerate(self.queue):
|
||||||
if i in indices_to_remove:
|
if i in indices_to_remove:
|
||||||
continue
|
continue
|
||||||
@@ -285,10 +375,23 @@ class DecodePreallocQueue:
|
|||||||
if self.req_to_metadata_buffer_idx_allocator.available_size() <= 0:
|
if self.req_to_metadata_buffer_idx_allocator.available_size() <= 0:
|
||||||
break
|
break
|
||||||
|
|
||||||
|
# Memory estimation: don't add if the projected memory cannot be met
|
||||||
|
# TODO: add new_token ratio
|
||||||
|
origin_input_len = len(decode_req.req.origin_input_ids)
|
||||||
required_tokens_for_request = (
|
required_tokens_for_request = (
|
||||||
len(decode_req.req.origin_input_ids) + self.num_reserved_decode_tokens
|
origin_input_len + self.num_reserved_decode_tokens
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if (
|
||||||
|
max(
|
||||||
|
required_tokens_for_request,
|
||||||
|
origin_input_len
|
||||||
|
+ decode_req.req.sampling_params.max_new_tokens
|
||||||
|
- retractable_tokens,
|
||||||
|
)
|
||||||
|
> allocatable_tokens
|
||||||
|
):
|
||||||
|
break
|
||||||
if required_tokens_for_request > allocatable_tokens:
|
if required_tokens_for_request > allocatable_tokens:
|
||||||
break
|
break
|
||||||
|
|
||||||
@@ -321,15 +424,35 @@ class DecodePreallocQueue:
|
|||||||
|
|
||||||
return preallocated_reqs
|
return preallocated_reqs
|
||||||
|
|
||||||
def _allocatable_tokens(self) -> int:
|
def _allocatable_tokens(
|
||||||
allocatable_tokens = (
|
self, retractable_tokens: Optional[int] = None, count_retracted: bool = True
|
||||||
self.token_to_kv_pool_allocator.available_size()
|
) -> int:
|
||||||
- self.num_reserved_decode_tokens
|
need_space_for_single_req = (
|
||||||
|
max(
|
||||||
|
[
|
||||||
|
x.sampling_params.max_new_tokens
|
||||||
|
+ len(x.origin_input_ids)
|
||||||
|
- retractable_tokens
|
||||||
|
for x in self.scheduler.running_batch.reqs
|
||||||
|
]
|
||||||
|
)
|
||||||
|
if retractable_tokens is not None
|
||||||
|
and len(self.scheduler.running_batch.reqs) > 0
|
||||||
|
else 0
|
||||||
|
)
|
||||||
|
|
||||||
|
available_size = self.token_to_kv_pool_allocator.available_size()
|
||||||
|
|
||||||
|
allocatable_tokens = available_size - max(
|
||||||
|
# preserve some space for future decode
|
||||||
|
self.num_reserved_decode_tokens
|
||||||
* (
|
* (
|
||||||
len(self.scheduler.running_batch.reqs)
|
len(self.scheduler.running_batch.reqs)
|
||||||
+ len(self.transfer_queue.queue)
|
+ len(self.transfer_queue.queue)
|
||||||
+ len(self.scheduler.waiting_queue)
|
+ len(self.scheduler.waiting_queue)
|
||||||
)
|
),
|
||||||
|
# make sure each request can finish if reach max_tokens with all other requests retracted
|
||||||
|
need_space_for_single_req,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Note: if the last fake extend just finishes, and we enter `pop_preallocated` immediately in the next iteration
|
# Note: if the last fake extend just finishes, and we enter `pop_preallocated` immediately in the next iteration
|
||||||
@@ -342,15 +465,27 @@ class DecodePreallocQueue:
|
|||||||
self.scheduler.last_batch.reqs
|
self.scheduler.last_batch.reqs
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if count_retracted:
|
||||||
|
allocatable_tokens -= sum(
|
||||||
|
[
|
||||||
|
len(req.origin_input_ids)
|
||||||
|
+ len(req.output_ids)
|
||||||
|
+ self.num_reserved_decode_tokens
|
||||||
|
for req in self.retracted_queue
|
||||||
|
]
|
||||||
|
)
|
||||||
return allocatable_tokens
|
return allocatable_tokens
|
||||||
|
|
||||||
def _pre_alloc(self, req: Req) -> torch.Tensor:
|
def _pre_alloc(self, req: Req) -> torch.Tensor:
|
||||||
"""Pre-allocate the memory for req_to_token and token_kv_pool"""
|
"""Pre-allocate the memory for req_to_token and token_kv_pool"""
|
||||||
req_pool_indices = self.req_to_token_pool.alloc(1)
|
req_pool_indices = self.req_to_token_pool.alloc(1)
|
||||||
|
|
||||||
assert req_pool_indices is not None
|
assert (
|
||||||
|
req_pool_indices is not None
|
||||||
|
), "req_pool_indices is full! There is a bug in memory estimation."
|
||||||
|
|
||||||
req.req_pool_idx = req_pool_indices[0]
|
req.req_pool_idx = req_pool_indices[0]
|
||||||
|
|
||||||
if self.token_to_kv_pool_allocator.page_size == 1:
|
if self.token_to_kv_pool_allocator.page_size == 1:
|
||||||
kv_loc = self.token_to_kv_pool_allocator.alloc(
|
kv_loc = self.token_to_kv_pool_allocator.alloc(
|
||||||
len(req.origin_input_ids) + max(len(req.output_ids) - 1, 0)
|
len(req.origin_input_ids) + max(len(req.output_ids) - 1, 0)
|
||||||
@@ -375,7 +510,10 @@ class DecodePreallocQueue:
|
|||||||
),
|
),
|
||||||
extend_num_tokens=num_tokens,
|
extend_num_tokens=num_tokens,
|
||||||
)
|
)
|
||||||
assert kv_loc is not None
|
|
||||||
|
assert (
|
||||||
|
kv_loc is not None
|
||||||
|
), "KV cache is full! There is a bug in memory estimation."
|
||||||
|
|
||||||
self.req_to_token_pool.write((req.req_pool_idx, slice(0, len(kv_loc))), kv_loc)
|
self.req_to_token_pool.write((req.req_pool_idx, slice(0, len(kv_loc))), kv_loc)
|
||||||
|
|
||||||
@@ -395,6 +533,7 @@ class DecodeTransferQueue:
|
|||||||
self,
|
self,
|
||||||
gloo_group: ProcessGroup,
|
gloo_group: ProcessGroup,
|
||||||
req_to_metadata_buffer_idx_allocator: ReqToMetadataIdxAllocator,
|
req_to_metadata_buffer_idx_allocator: ReqToMetadataIdxAllocator,
|
||||||
|
tp_rank: int,
|
||||||
metadata_buffers: MetadataBuffers,
|
metadata_buffers: MetadataBuffers,
|
||||||
scheduler: Scheduler,
|
scheduler: Scheduler,
|
||||||
tree_cache: BasePrefixCache,
|
tree_cache: BasePrefixCache,
|
||||||
@@ -402,6 +541,7 @@ class DecodeTransferQueue:
|
|||||||
self.queue: List[DecodeRequest] = []
|
self.queue: List[DecodeRequest] = []
|
||||||
self.gloo_group = gloo_group
|
self.gloo_group = gloo_group
|
||||||
self.req_to_metadata_buffer_idx_allocator = req_to_metadata_buffer_idx_allocator
|
self.req_to_metadata_buffer_idx_allocator = req_to_metadata_buffer_idx_allocator
|
||||||
|
self.tp_rank = tp_rank
|
||||||
self.metadata_buffers = metadata_buffers
|
self.metadata_buffers = metadata_buffers
|
||||||
self.scheduler = scheduler
|
self.scheduler = scheduler
|
||||||
self.tree_cache = tree_cache
|
self.tree_cache = tree_cache
|
||||||
@@ -412,10 +552,9 @@ class DecodeTransferQueue:
|
|||||||
def extend(self, decode_reqs: List[DecodeRequest]) -> None:
|
def extend(self, decode_reqs: List[DecodeRequest]) -> None:
|
||||||
self.queue.extend(decode_reqs)
|
self.queue.extend(decode_reqs)
|
||||||
|
|
||||||
def pop_transferred(self) -> List[DecodeRequest]:
|
def pop_transferred(self) -> List[Req]:
|
||||||
if not self.queue:
|
if not self.queue:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
polls = poll_and_all_reduce(
|
polls = poll_and_all_reduce(
|
||||||
[decode_req.kv_receiver for decode_req in self.queue], self.gloo_group
|
[decode_req.kv_receiver for decode_req in self.queue], self.gloo_group
|
||||||
)
|
)
|
||||||
@@ -424,7 +563,7 @@ class DecodeTransferQueue:
|
|||||||
indices_to_remove = set()
|
indices_to_remove = set()
|
||||||
for i, (decode_req, poll) in enumerate(zip(self.queue, polls)):
|
for i, (decode_req, poll) in enumerate(zip(self.queue, polls)):
|
||||||
if poll == KVPoll.Failed:
|
if poll == KVPoll.Failed:
|
||||||
error_message = f"Decode transfer failed for request rank={self.scheduler.tp_rank} {decode_req.req.rid=} {decode_req.req.bootstrap_room=}"
|
error_message = f"Decode transfer failed for request rank={self.tp_rank} {decode_req.req.rid=} {decode_req.req.bootstrap_room=}"
|
||||||
try:
|
try:
|
||||||
decode_req.kv_receiver.failure_exception()
|
decode_req.kv_receiver.failure_exception()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -543,7 +682,8 @@ class SchedulerDisaggregationDecodeMixin:
|
|||||||
batch, _ = self._prepare_idle_batch_and_run(None)
|
batch, _ = self._prepare_idle_batch_and_run(None)
|
||||||
|
|
||||||
if batch is None and (
|
if batch is None and (
|
||||||
len(self.disagg_decode_transfer_queue.queue)
|
len(self.waiting_queue)
|
||||||
|
+ len(self.disagg_decode_transfer_queue.queue)
|
||||||
+ len(self.disagg_decode_prealloc_queue.queue)
|
+ len(self.disagg_decode_prealloc_queue.queue)
|
||||||
== 0
|
== 0
|
||||||
):
|
):
|
||||||
@@ -622,7 +762,8 @@ class SchedulerDisaggregationDecodeMixin:
|
|||||||
self.process_batch_result(tmp_batch, tmp_result)
|
self.process_batch_result(tmp_batch, tmp_result)
|
||||||
|
|
||||||
if batch is None and (
|
if batch is None and (
|
||||||
len(self.disagg_decode_transfer_queue.queue)
|
len(self.waiting_queue)
|
||||||
|
+ len(self.disagg_decode_transfer_queue.queue)
|
||||||
+ len(self.disagg_decode_prealloc_queue.queue)
|
+ len(self.disagg_decode_prealloc_queue.queue)
|
||||||
== 0
|
== 0
|
||||||
):
|
):
|
||||||
@@ -716,6 +857,13 @@ class SchedulerDisaggregationDecodeMixin:
|
|||||||
return new_batch
|
return new_batch
|
||||||
|
|
||||||
def process_decode_queue(self: Scheduler):
|
def process_decode_queue(self: Scheduler):
|
||||||
|
# try to resume retracted requests if there are enough space for another `num_reserved_decode_tokens` decode steps
|
||||||
|
resumed_reqs = self.disagg_decode_prealloc_queue.resume_retracted_reqs()
|
||||||
|
self.waiting_queue.extend(resumed_reqs)
|
||||||
|
if len(self.disagg_decode_prealloc_queue.retracted_queue) > 0:
|
||||||
|
# if there are still retracted requests, we do not allocate new requests
|
||||||
|
return
|
||||||
|
|
||||||
req_conns = self.disagg_decode_prealloc_queue.pop_preallocated()
|
req_conns = self.disagg_decode_prealloc_queue.pop_preallocated()
|
||||||
self.disagg_decode_transfer_queue.extend(req_conns)
|
self.disagg_decode_transfer_queue.extend(req_conns)
|
||||||
alloc_reqs = (
|
alloc_reqs = (
|
||||||
|
|||||||
@@ -25,6 +25,7 @@ from collections import deque
|
|||||||
from http import HTTPStatus
|
from http import HTTPStatus
|
||||||
from typing import TYPE_CHECKING, List, Optional
|
from typing import TYPE_CHECKING, List, Optional
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from sglang.srt.disaggregation.base import BaseKVManager, KVPoll
|
from sglang.srt.disaggregation.base import BaseKVManager, KVPoll
|
||||||
@@ -575,6 +576,7 @@ class SchedulerDisaggregationPrefillMixin:
|
|||||||
self.req_to_token_pool.req_to_token[req.req_pool_idx, start_idx:end_idx]
|
self.req_to_token_pool.req_to_token[req.req_pool_idx, start_idx:end_idx]
|
||||||
.cpu()
|
.cpu()
|
||||||
.numpy()
|
.numpy()
|
||||||
|
.astype(np.int64)
|
||||||
)
|
)
|
||||||
req.start_send_idx = end_idx
|
req.start_send_idx = end_idx
|
||||||
if last_chunk:
|
if last_chunk:
|
||||||
|
|||||||
@@ -1415,6 +1415,11 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
|||||||
req = self.reqs[idx]
|
req = self.reqs[idx]
|
||||||
retracted_reqs.append(req)
|
retracted_reqs.append(req)
|
||||||
|
|
||||||
|
if server_args.disaggregation_mode == "decode":
|
||||||
|
req.offload_kv_cache(
|
||||||
|
self.req_to_token_pool, self.token_to_kv_pool_allocator
|
||||||
|
)
|
||||||
|
|
||||||
if isinstance(self.tree_cache, ChunkCache):
|
if isinstance(self.tree_cache, ChunkCache):
|
||||||
# ChunkCache does not have eviction
|
# ChunkCache does not have eviction
|
||||||
token_indices = self.req_to_token_pool.req_to_token[
|
token_indices = self.req_to_token_pool.req_to_token[
|
||||||
@@ -1446,6 +1451,12 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
|||||||
|
|
||||||
req.reset_for_retract()
|
req.reset_for_retract()
|
||||||
|
|
||||||
|
if len(retracted_reqs) == 0:
|
||||||
|
# Corner case: only one request left
|
||||||
|
raise ValueError(
|
||||||
|
"Failed to retract any request. No space left for only one request."
|
||||||
|
)
|
||||||
|
|
||||||
self.filter_batch(keep_indices=sorted_indices)
|
self.filter_batch(keep_indices=sorted_indices)
|
||||||
|
|
||||||
# Reqs in batch are filtered
|
# Reqs in batch are filtered
|
||||||
|
|||||||
@@ -628,6 +628,7 @@ class Scheduler(
|
|||||||
self.disagg_decode_transfer_queue = DecodeTransferQueue(
|
self.disagg_decode_transfer_queue = DecodeTransferQueue(
|
||||||
gloo_group=self.attn_tp_cpu_group,
|
gloo_group=self.attn_tp_cpu_group,
|
||||||
req_to_metadata_buffer_idx_allocator=self.req_to_metadata_buffer_idx_allocator,
|
req_to_metadata_buffer_idx_allocator=self.req_to_metadata_buffer_idx_allocator,
|
||||||
|
tp_rank=self.tp_rank,
|
||||||
metadata_buffers=self.disagg_metadata_buffers,
|
metadata_buffers=self.disagg_metadata_buffers,
|
||||||
scheduler=self,
|
scheduler=self,
|
||||||
tree_cache=self.tree_cache,
|
tree_cache=self.tree_cache,
|
||||||
@@ -650,7 +651,11 @@ class Scheduler(
|
|||||||
gloo_group=self.attn_tp_cpu_group,
|
gloo_group=self.attn_tp_cpu_group,
|
||||||
tp_rank=self.tp_rank,
|
tp_rank=self.tp_rank,
|
||||||
tp_size=self.tp_size,
|
tp_size=self.tp_size,
|
||||||
|
dp_size=self.server_args.dp_size,
|
||||||
|
gpu_id=self.gpu_id,
|
||||||
bootstrap_port=self.server_args.disaggregation_bootstrap_port,
|
bootstrap_port=self.server_args.disaggregation_bootstrap_port,
|
||||||
|
max_total_num_tokens=self.max_total_num_tokens,
|
||||||
|
prefill_pp_size=self.server_args.disaggregation_prefill_pp,
|
||||||
transfer_backend=self.transfer_backend,
|
transfer_backend=self.transfer_backend,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -1124,14 +1129,14 @@ class Scheduler(
|
|||||||
else:
|
else:
|
||||||
self.waiting_queue.append(req)
|
self.waiting_queue.append(req)
|
||||||
|
|
||||||
def _extend_requests_to_queue(self, reqs: List[Req]):
|
def _extend_requests_to_queue(self, reqs: List[Req], is_retracted: bool = False):
|
||||||
if self.disaggregation_mode == DisaggregationMode.PREFILL:
|
if self.disaggregation_mode == DisaggregationMode.PREFILL:
|
||||||
self.disagg_prefill_bootstrap_queue.extend(
|
self.disagg_prefill_bootstrap_queue.extend(
|
||||||
reqs, self.model_config.num_key_value_heads
|
reqs, self.model_config.num_key_value_heads
|
||||||
)
|
)
|
||||||
elif self.disaggregation_mode == DisaggregationMode.DECODE:
|
elif self.disaggregation_mode == DisaggregationMode.DECODE:
|
||||||
# If this is a decode server, we put the request to the decode pending prealloc queue
|
# If this is a decode server, we put the request to the decode pending prealloc queue
|
||||||
self.disagg_decode_prealloc_queue.extend(reqs)
|
self.disagg_decode_prealloc_queue.extend(reqs, is_retracted)
|
||||||
else:
|
else:
|
||||||
self.waiting_queue.extend(reqs)
|
self.waiting_queue.extend(reqs)
|
||||||
|
|
||||||
@@ -1274,6 +1279,7 @@ class Scheduler(
|
|||||||
|
|
||||||
if self.disaggregation_mode == DisaggregationMode.DECODE:
|
if self.disaggregation_mode == DisaggregationMode.DECODE:
|
||||||
msg += f"pre-allocated usage: {self.num_tokens_pre_allocated / self.max_total_num_tokens:.2f}, "
|
msg += f"pre-allocated usage: {self.num_tokens_pre_allocated / self.max_total_num_tokens:.2f}, "
|
||||||
|
msg += f"#retracted-req: {len(self.disagg_decode_prealloc_queue.retracted_queue)}, "
|
||||||
|
|
||||||
msg += (
|
msg += (
|
||||||
f"cuda graph: {can_run_cuda_graph}, "
|
f"cuda graph: {can_run_cuda_graph}, "
|
||||||
@@ -1575,7 +1581,7 @@ class Scheduler(
|
|||||||
f"#retracted_reqs: {len(retracted_reqs)}, "
|
f"#retracted_reqs: {len(retracted_reqs)}, "
|
||||||
f"#new_token_ratio: {old_ratio:.4f} -> {self.new_token_ratio:.4f}"
|
f"#new_token_ratio: {old_ratio:.4f} -> {self.new_token_ratio:.4f}"
|
||||||
)
|
)
|
||||||
self._extend_requests_to_queue(retracted_reqs)
|
self._extend_requests_to_queue(retracted_reqs, is_retracted=True)
|
||||||
else:
|
else:
|
||||||
self.new_token_ratio = max(
|
self.new_token_ratio = max(
|
||||||
self.new_token_ratio - self.new_token_ratio_decay,
|
self.new_token_ratio - self.new_token_ratio_decay,
|
||||||
|
|||||||
@@ -234,6 +234,12 @@ class TokenToKVPoolAllocator:
|
|||||||
self.is_not_in_free_group = True
|
self.is_not_in_free_group = True
|
||||||
self.free_group = []
|
self.free_group = []
|
||||||
|
|
||||||
|
def get_cpu_copy(self, indices):
|
||||||
|
return self._kvcache.get_cpu_copy(indices)
|
||||||
|
|
||||||
|
def load_cpu_copy(self, kv_cache_cpu, indices):
|
||||||
|
return self._kvcache.load_cpu_copy(kv_cache_cpu, indices)
|
||||||
|
|
||||||
|
|
||||||
class MHATokenToKVPool(KVCache):
|
class MHATokenToKVPool(KVCache):
|
||||||
|
|
||||||
@@ -265,6 +271,8 @@ class MHATokenToKVPool(KVCache):
|
|||||||
self.head_dim = head_dim
|
self.head_dim = head_dim
|
||||||
self._create_buffers()
|
self._create_buffers()
|
||||||
|
|
||||||
|
# used for chunked cpu-offloading
|
||||||
|
self.chunk_size = 8192
|
||||||
self.layer_transfer_counter = None
|
self.layer_transfer_counter = None
|
||||||
self.device_module = torch.get_device_module(self.device)
|
self.device_module = torch.get_device_module(self.device)
|
||||||
self.alt_stream = self.device_module.Stream() if _is_cuda else None
|
self.alt_stream = self.device_module.Stream() if _is_cuda else None
|
||||||
@@ -329,6 +337,39 @@ class MHATokenToKVPool(KVCache):
|
|||||||
]
|
]
|
||||||
return kv_data_ptrs, kv_data_lens, kv_item_lens
|
return kv_data_ptrs, kv_data_lens, kv_item_lens
|
||||||
|
|
||||||
|
def get_cpu_copy(self, indices):
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
kv_cache_cpu = []
|
||||||
|
for layer_id in range(self.layer_num):
|
||||||
|
kv_cache_cpu.append([])
|
||||||
|
for i in range(0, len(indices), self.chunk_size):
|
||||||
|
chunk_indices = indices[i : i + self.chunk_size]
|
||||||
|
k_cpu = self.k_buffer[layer_id][chunk_indices].to(
|
||||||
|
"cpu", non_blocking=True
|
||||||
|
)
|
||||||
|
v_cpu = self.v_buffer[layer_id][chunk_indices].to(
|
||||||
|
"cpu", non_blocking=True
|
||||||
|
)
|
||||||
|
kv_cache_cpu[-1].append([k_cpu, v_cpu])
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
return kv_cache_cpu
|
||||||
|
|
||||||
|
def load_cpu_copy(self, kv_cache_cpu, indices):
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
for layer_id in range(self.layer_num):
|
||||||
|
for i in range(0, len(indices), self.chunk_size):
|
||||||
|
chunk_indices = indices[i : i + self.chunk_size]
|
||||||
|
k_cpu, v_cpu = (
|
||||||
|
kv_cache_cpu[layer_id][i // self.chunk_size][0],
|
||||||
|
kv_cache_cpu[layer_id][i // self.chunk_size][1],
|
||||||
|
)
|
||||||
|
assert k_cpu.shape[0] == v_cpu.shape[0] == len(chunk_indices)
|
||||||
|
k_chunk = k_cpu.to(self.k_buffer[0].device, non_blocking=True)
|
||||||
|
v_chunk = v_cpu.to(self.v_buffer[0].device, non_blocking=True)
|
||||||
|
self.k_buffer[layer_id][chunk_indices] = k_chunk
|
||||||
|
self.v_buffer[layer_id][chunk_indices] = v_chunk
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
|
||||||
# Todo: different memory layout
|
# Todo: different memory layout
|
||||||
def get_flat_data(self, indices):
|
def get_flat_data(self, indices):
|
||||||
# prepare a large chunk of contiguous data for efficient transfer
|
# prepare a large chunk of contiguous data for efficient transfer
|
||||||
|
|||||||
@@ -469,5 +469,132 @@ class TestDisaggregationMooncakeSpec(CustomTestCase):
|
|||||||
self.assertGreater(metrics["accuracy"], 0.20)
|
self.assertGreater(metrics["accuracy"], 0.20)
|
||||||
|
|
||||||
|
|
||||||
|
class TestDisaggregationSimulatedRetract(CustomTestCase):
|
||||||
|
@classmethod
|
||||||
|
def setUpClass(cls):
|
||||||
|
os.environ["SGLANG_TEST_RETRACT"] = "true"
|
||||||
|
cls.model = DEFAULT_MODEL_NAME_FOR_TEST
|
||||||
|
parsed_url = urlparse(DEFAULT_URL_FOR_TEST)
|
||||||
|
cls.base_host = parsed_url.hostname
|
||||||
|
base_port = str(parsed_url.port)
|
||||||
|
cls.lb_port = base_port
|
||||||
|
cls.prefill_port = f"{int(base_port) + 100}"
|
||||||
|
cls.decode_port = f"{int(base_port) + 200}"
|
||||||
|
cls.prefill_url = f"http://{cls.base_host}:{cls.prefill_port}"
|
||||||
|
cls.decode_url = f"http://{cls.base_host}:{cls.decode_port}"
|
||||||
|
cls.lb_url = f"http://{cls.base_host}:{cls.lb_port}"
|
||||||
|
print(f"{cls.base_host=} {cls.lb_port=} {cls.prefill_port=} {cls.decode_port=}")
|
||||||
|
|
||||||
|
# Non blocking start servers
|
||||||
|
cls.start_prefill()
|
||||||
|
cls.start_decode()
|
||||||
|
|
||||||
|
# Block until both
|
||||||
|
cls.wait_server_ready(cls.prefill_url + "/health")
|
||||||
|
cls.wait_server_ready(cls.decode_url + "/health")
|
||||||
|
|
||||||
|
lb_command = [
|
||||||
|
"python3",
|
||||||
|
"-m",
|
||||||
|
"sglang.srt.disaggregation.mini_lb",
|
||||||
|
"--prefill",
|
||||||
|
cls.prefill_url,
|
||||||
|
"--decode",
|
||||||
|
cls.decode_url,
|
||||||
|
"--host",
|
||||||
|
cls.base_host,
|
||||||
|
"--port",
|
||||||
|
cls.lb_port,
|
||||||
|
]
|
||||||
|
|
||||||
|
print("Starting load balancer:", " ".join(lb_command))
|
||||||
|
cls.process_lb = subprocess.Popen(
|
||||||
|
lb_command, stdout=subprocess.PIPE, stderr=subprocess.PIPE
|
||||||
|
)
|
||||||
|
cls.wait_server_ready(cls.lb_url + "/health")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def start_prefill(cls):
|
||||||
|
prefill_args = [
|
||||||
|
"--trust-remote-code",
|
||||||
|
"--disaggregation-mode",
|
||||||
|
"prefill",
|
||||||
|
"--tp",
|
||||||
|
"1",
|
||||||
|
"--disaggregation-ib-device",
|
||||||
|
"mlx5_roce0",
|
||||||
|
]
|
||||||
|
cls.process_prefill = popen_launch_pd_server(
|
||||||
|
cls.model,
|
||||||
|
cls.prefill_url,
|
||||||
|
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||||
|
other_args=prefill_args,
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def start_decode(cls):
|
||||||
|
decode_args = [
|
||||||
|
"--trust-remote-code",
|
||||||
|
"--disaggregation-mode",
|
||||||
|
"decode",
|
||||||
|
"--tp",
|
||||||
|
"1",
|
||||||
|
"--base-gpu-id",
|
||||||
|
"1",
|
||||||
|
"--disaggregation-ib-device",
|
||||||
|
"mlx5_roce1",
|
||||||
|
]
|
||||||
|
cls.process_decode = popen_launch_pd_server(
|
||||||
|
cls.model,
|
||||||
|
cls.decode_url,
|
||||||
|
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||||
|
other_args=decode_args,
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def wait_server_ready(cls, url, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH):
|
||||||
|
start_time = time.perf_counter()
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
response = requests.get(url)
|
||||||
|
if response.status_code == 200:
|
||||||
|
print(f"Server {url} is ready")
|
||||||
|
return
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
if time.perf_counter() - start_time > timeout:
|
||||||
|
raise RuntimeError(f"Server {url} failed to start in {timeout}s")
|
||||||
|
time.sleep(1)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def tearDownClass(cls):
|
||||||
|
os.environ.pop("SGLANG_TEST_RETRACT")
|
||||||
|
for process in [cls.process_lb, cls.process_decode, cls.process_prefill]:
|
||||||
|
if process:
|
||||||
|
try:
|
||||||
|
kill_process_tree(process.pid)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error killing process {process.pid}: {e}")
|
||||||
|
|
||||||
|
# wait for 5 seconds
|
||||||
|
time.sleep(5)
|
||||||
|
|
||||||
|
def test_gsm8k(self):
|
||||||
|
args = SimpleNamespace(
|
||||||
|
num_shots=5,
|
||||||
|
data_path=None,
|
||||||
|
num_questions=200,
|
||||||
|
max_new_tokens=512,
|
||||||
|
parallel=128,
|
||||||
|
host=f"http://{self.base_host}",
|
||||||
|
port=int(self.lb_port),
|
||||||
|
)
|
||||||
|
metrics = run_eval_few_shot_gsm8k(args)
|
||||||
|
print(f"Evaluation metrics: {metrics}")
|
||||||
|
|
||||||
|
self.assertGreater(metrics["accuracy"], 0.62)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|||||||
Reference in New Issue
Block a user