[PD] Support decode retract and update decode.py (#7196)
This commit is contained in:
@@ -31,7 +31,7 @@ import numpy as np
|
||||
import torch
|
||||
from torch.distributed import ProcessGroup
|
||||
|
||||
from sglang.srt.disaggregation.base import BaseKVManager, BaseKVReceiver, KVArgs, KVPoll
|
||||
from sglang.srt.disaggregation.base import BaseKVManager, BaseKVReceiver, KVPoll
|
||||
from sglang.srt.disaggregation.utils import (
|
||||
FAKE_BOOTSTRAP_HOST,
|
||||
DisaggregationMode,
|
||||
@@ -45,9 +45,17 @@ from sglang.srt.disaggregation.utils import (
|
||||
poll_and_all_reduce,
|
||||
prepare_abort,
|
||||
)
|
||||
from sglang.srt.managers.schedule_batch import FINISH_ABORT, ScheduleBatch
|
||||
from sglang.srt.managers.schedule_batch import (
|
||||
FINISH_ABORT,
|
||||
ScheduleBatch,
|
||||
global_server_args_dict,
|
||||
)
|
||||
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
|
||||
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
|
||||
from sglang.srt.mem_cache.memory_pool import (
|
||||
KVCache,
|
||||
ReqToTokenPool,
|
||||
TokenToKVPoolAllocator,
|
||||
)
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardMode
|
||||
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
|
||||
|
||||
@@ -145,7 +153,11 @@ class DecodePreallocQueue:
|
||||
gloo_group: ProcessGroup,
|
||||
tp_rank: int,
|
||||
tp_size: int,
|
||||
dp_size: int,
|
||||
gpu_id: int,
|
||||
bootstrap_port: int,
|
||||
max_total_num_tokens: int,
|
||||
prefill_pp_size: int,
|
||||
transfer_backend: TransferBackend,
|
||||
):
|
||||
self.req_to_token_pool = req_to_token_pool
|
||||
@@ -161,25 +173,35 @@ class DecodePreallocQueue:
|
||||
self.gloo_group = gloo_group
|
||||
self.tp_rank = tp_rank
|
||||
self.tp_size = tp_size
|
||||
self.dp_size = dp_size
|
||||
self.gpu_id = gpu_id
|
||||
self.bootstrap_port = bootstrap_port
|
||||
|
||||
self.max_total_num_tokens = max_total_num_tokens
|
||||
self.prefill_pp_size = prefill_pp_size
|
||||
self.num_reserved_decode_tokens = int(
|
||||
os.environ.get("SGLANG_NUM_RESERVED_DECODE_TOKENS", "512")
|
||||
)
|
||||
|
||||
self.transfer_backend = transfer_backend
|
||||
# Queue for requests pending pre-allocation
|
||||
self.queue: List[DecodeRequest] = []
|
||||
self.transfer_backend = transfer_backend
|
||||
self.retracted_queue: List[Req] = []
|
||||
self.prefill_pp_size = prefill_pp_size
|
||||
self.kv_manager = self._init_kv_manager()
|
||||
|
||||
def _init_kv_manager(self) -> BaseKVManager:
|
||||
kv_args = KVArgs()
|
||||
kv_args.engine_rank = self.tp_rank
|
||||
kv_args_class = get_kv_class(self.transfer_backend, KVClassType.KVARGS)
|
||||
kv_args = kv_args_class()
|
||||
|
||||
attn_tp_size = self.tp_size // self.dp_size
|
||||
kv_args.engine_rank = self.tp_rank % (attn_tp_size)
|
||||
kv_args.decode_tp_size = attn_tp_size
|
||||
kv_args.prefill_pp_size = self.prefill_pp_size
|
||||
kv_data_ptrs, kv_data_lens, kv_item_lens = (
|
||||
self.token_to_kv_pool.get_contiguous_buf_infos()
|
||||
)
|
||||
|
||||
if self.draft_token_to_kv_pool is not None:
|
||||
# We should also transfer draft model kv cache. The indices are
|
||||
# always shared with a target model.
|
||||
draft_kv_data_ptrs, draft_kv_data_lens, draft_kv_item_lens = (
|
||||
self.draft_token_to_kv_pool.get_contiguous_buf_infos()
|
||||
)
|
||||
@@ -194,6 +216,7 @@ class DecodePreallocQueue:
|
||||
kv_args.aux_data_ptrs, kv_args.aux_data_lens, kv_args.aux_item_lens = (
|
||||
self.metadata_buffers.get_buf_infos()
|
||||
)
|
||||
|
||||
kv_args.ib_device = self.scheduler.server_args.disaggregation_ib_device
|
||||
kv_args.gpu_id = self.scheduler.gpu_id
|
||||
kv_manager_class = get_kv_class(self.transfer_backend, KVClassType.MANAGER)
|
||||
@@ -205,27 +228,83 @@ class DecodePreallocQueue:
|
||||
)
|
||||
return kv_manager
|
||||
|
||||
def add(self, req: Req) -> None:
|
||||
def add(self, req: Req, is_retracted: bool = False) -> None:
|
||||
"""Add a request to the pending queue."""
|
||||
if req.bootstrap_host == FAKE_BOOTSTRAP_HOST:
|
||||
# Fake transfer for warmup reqs
|
||||
kv_receiver_class = get_kv_class(TransferBackend.FAKE, KVClassType.RECEIVER)
|
||||
else:
|
||||
kv_receiver_class = get_kv_class(
|
||||
self.transfer_backend, KVClassType.RECEIVER
|
||||
)
|
||||
kv_receiver = kv_receiver_class(
|
||||
mgr=self.kv_manager,
|
||||
bootstrap_addr=f"{req.bootstrap_host}:{req.bootstrap_port}",
|
||||
bootstrap_room=req.bootstrap_room,
|
||||
data_parallel_rank=req.data_parallel_rank,
|
||||
)
|
||||
self.queue.append(DecodeRequest(req=req, kv_receiver=kv_receiver))
|
||||
if self._check_if_req_exceed_kv_capacity(req):
|
||||
return
|
||||
|
||||
def extend(self, reqs: List[Req]) -> None:
|
||||
if is_retracted:
|
||||
self.retracted_queue.append(req)
|
||||
else:
|
||||
if req.bootstrap_host == FAKE_BOOTSTRAP_HOST:
|
||||
kv_receiver_class = get_kv_class(
|
||||
TransferBackend.FAKE, KVClassType.RECEIVER
|
||||
)
|
||||
else:
|
||||
kv_receiver_class = get_kv_class(
|
||||
self.transfer_backend, KVClassType.RECEIVER
|
||||
)
|
||||
|
||||
kv_receiver = kv_receiver_class(
|
||||
mgr=self.kv_manager,
|
||||
bootstrap_addr=f"{req.bootstrap_host}:{req.bootstrap_port}",
|
||||
bootstrap_room=req.bootstrap_room,
|
||||
)
|
||||
|
||||
self.queue.append(
|
||||
DecodeRequest(req=req, kv_receiver=kv_receiver, waiting_for_input=False)
|
||||
)
|
||||
|
||||
def _check_if_req_exceed_kv_capacity(self, req: Req) -> bool:
|
||||
if len(req.origin_input_ids) > self.max_total_num_tokens:
|
||||
message = f"Request {req.rid} exceeds the maximum number of tokens: {len(req.origin_input_ids)} > {self.max_total_num_tokens}"
|
||||
logger.error(message)
|
||||
prepare_abort(req, message)
|
||||
self.scheduler.stream_output([req], req.return_logprob)
|
||||
return True
|
||||
return False
|
||||
|
||||
def extend(self, reqs: List[Req], is_retracted: bool = False) -> None:
|
||||
"""Add a request to the pending queue."""
|
||||
for req in reqs:
|
||||
self.add(req)
|
||||
self.add(req, is_retracted=is_retracted)
|
||||
|
||||
def resume_retracted_reqs(self) -> List[Req]:
|
||||
# TODO refactor the scheduling part, reuse with the unified engine logic as much as possible
|
||||
|
||||
# allocate memory
|
||||
resumed_reqs = []
|
||||
indices_to_remove = set()
|
||||
allocatable_tokens = self._allocatable_tokens(count_retracted=False)
|
||||
|
||||
for i, req in enumerate(self.retracted_queue):
|
||||
if self.req_to_token_pool.available_size() <= 0:
|
||||
break
|
||||
|
||||
required_tokens_for_request = (
|
||||
len(req.origin_input_ids)
|
||||
+ len(req.output_ids)
|
||||
+ self.num_reserved_decode_tokens
|
||||
)
|
||||
if required_tokens_for_request > allocatable_tokens:
|
||||
break
|
||||
|
||||
resumed_reqs.append(req)
|
||||
indices_to_remove.add(i)
|
||||
req.is_retracted = False
|
||||
self._pre_alloc(req)
|
||||
allocatable_tokens -= required_tokens_for_request
|
||||
|
||||
# load from cpu, release the cpu copy
|
||||
req.load_kv_cache(self.req_to_token_pool, self.token_to_kv_pool_allocator)
|
||||
|
||||
self.retracted_queue = [
|
||||
entry
|
||||
for i, entry in enumerate(self.retracted_queue)
|
||||
if i not in indices_to_remove
|
||||
]
|
||||
|
||||
return resumed_reqs
|
||||
|
||||
def _update_handshake_waiters(self) -> None:
|
||||
if not self.queue:
|
||||
@@ -255,6 +334,8 @@ class DecodePreallocQueue:
|
||||
error_message,
|
||||
status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unexpected poll case: {poll}")
|
||||
|
||||
def pop_preallocated(self) -> List[DecodeRequest]:
|
||||
"""Pop the preallocated requests from the pending queue (FIFO)."""
|
||||
@@ -262,8 +343,16 @@ class DecodePreallocQueue:
|
||||
|
||||
preallocated_reqs = []
|
||||
indices_to_remove = set()
|
||||
allocatable_tokens = self._allocatable_tokens()
|
||||
|
||||
# We need to make sure that the sum of inflight tokens and allocatable tokens is greater than maximum input+output length of each inflight request
|
||||
# Otherwise it is possible for one request running decode out of memory, while all other requests are in the transfer queue that cannot be retracted.
|
||||
retractable_tokens = sum(
|
||||
len(r.origin_input_ids) + len(r.output_ids)
|
||||
for r in self.scheduler.running_batch.reqs
|
||||
)
|
||||
allocatable_tokens = self._allocatable_tokens(
|
||||
retractable_tokens=retractable_tokens, count_retracted=True
|
||||
)
|
||||
# First, remove all failed requests from the queue
|
||||
for i, decode_req in enumerate(self.queue):
|
||||
if isinstance(decode_req.req.finished_reason, FINISH_ABORT):
|
||||
@@ -272,6 +361,7 @@ class DecodePreallocQueue:
|
||||
)
|
||||
indices_to_remove.add(i)
|
||||
|
||||
# Then, preallocate the remaining requests if possible
|
||||
for i, decode_req in enumerate(self.queue):
|
||||
if i in indices_to_remove:
|
||||
continue
|
||||
@@ -285,10 +375,23 @@ class DecodePreallocQueue:
|
||||
if self.req_to_metadata_buffer_idx_allocator.available_size() <= 0:
|
||||
break
|
||||
|
||||
# Memory estimation: don't add if the projected memory cannot be met
|
||||
# TODO: add new_token ratio
|
||||
origin_input_len = len(decode_req.req.origin_input_ids)
|
||||
required_tokens_for_request = (
|
||||
len(decode_req.req.origin_input_ids) + self.num_reserved_decode_tokens
|
||||
origin_input_len + self.num_reserved_decode_tokens
|
||||
)
|
||||
|
||||
if (
|
||||
max(
|
||||
required_tokens_for_request,
|
||||
origin_input_len
|
||||
+ decode_req.req.sampling_params.max_new_tokens
|
||||
- retractable_tokens,
|
||||
)
|
||||
> allocatable_tokens
|
||||
):
|
||||
break
|
||||
if required_tokens_for_request > allocatable_tokens:
|
||||
break
|
||||
|
||||
@@ -321,15 +424,35 @@ class DecodePreallocQueue:
|
||||
|
||||
return preallocated_reqs
|
||||
|
||||
def _allocatable_tokens(self) -> int:
|
||||
allocatable_tokens = (
|
||||
self.token_to_kv_pool_allocator.available_size()
|
||||
- self.num_reserved_decode_tokens
|
||||
def _allocatable_tokens(
|
||||
self, retractable_tokens: Optional[int] = None, count_retracted: bool = True
|
||||
) -> int:
|
||||
need_space_for_single_req = (
|
||||
max(
|
||||
[
|
||||
x.sampling_params.max_new_tokens
|
||||
+ len(x.origin_input_ids)
|
||||
- retractable_tokens
|
||||
for x in self.scheduler.running_batch.reqs
|
||||
]
|
||||
)
|
||||
if retractable_tokens is not None
|
||||
and len(self.scheduler.running_batch.reqs) > 0
|
||||
else 0
|
||||
)
|
||||
|
||||
available_size = self.token_to_kv_pool_allocator.available_size()
|
||||
|
||||
allocatable_tokens = available_size - max(
|
||||
# preserve some space for future decode
|
||||
self.num_reserved_decode_tokens
|
||||
* (
|
||||
len(self.scheduler.running_batch.reqs)
|
||||
+ len(self.transfer_queue.queue)
|
||||
+ len(self.scheduler.waiting_queue)
|
||||
)
|
||||
),
|
||||
# make sure each request can finish if reach max_tokens with all other requests retracted
|
||||
need_space_for_single_req,
|
||||
)
|
||||
|
||||
# Note: if the last fake extend just finishes, and we enter `pop_preallocated` immediately in the next iteration
|
||||
@@ -342,15 +465,27 @@ class DecodePreallocQueue:
|
||||
self.scheduler.last_batch.reqs
|
||||
)
|
||||
|
||||
if count_retracted:
|
||||
allocatable_tokens -= sum(
|
||||
[
|
||||
len(req.origin_input_ids)
|
||||
+ len(req.output_ids)
|
||||
+ self.num_reserved_decode_tokens
|
||||
for req in self.retracted_queue
|
||||
]
|
||||
)
|
||||
return allocatable_tokens
|
||||
|
||||
def _pre_alloc(self, req: Req) -> torch.Tensor:
|
||||
"""Pre-allocate the memory for req_to_token and token_kv_pool"""
|
||||
req_pool_indices = self.req_to_token_pool.alloc(1)
|
||||
|
||||
assert req_pool_indices is not None
|
||||
assert (
|
||||
req_pool_indices is not None
|
||||
), "req_pool_indices is full! There is a bug in memory estimation."
|
||||
|
||||
req.req_pool_idx = req_pool_indices[0]
|
||||
|
||||
if self.token_to_kv_pool_allocator.page_size == 1:
|
||||
kv_loc = self.token_to_kv_pool_allocator.alloc(
|
||||
len(req.origin_input_ids) + max(len(req.output_ids) - 1, 0)
|
||||
@@ -375,7 +510,10 @@ class DecodePreallocQueue:
|
||||
),
|
||||
extend_num_tokens=num_tokens,
|
||||
)
|
||||
assert kv_loc is not None
|
||||
|
||||
assert (
|
||||
kv_loc is not None
|
||||
), "KV cache is full! There is a bug in memory estimation."
|
||||
|
||||
self.req_to_token_pool.write((req.req_pool_idx, slice(0, len(kv_loc))), kv_loc)
|
||||
|
||||
@@ -395,6 +533,7 @@ class DecodeTransferQueue:
|
||||
self,
|
||||
gloo_group: ProcessGroup,
|
||||
req_to_metadata_buffer_idx_allocator: ReqToMetadataIdxAllocator,
|
||||
tp_rank: int,
|
||||
metadata_buffers: MetadataBuffers,
|
||||
scheduler: Scheduler,
|
||||
tree_cache: BasePrefixCache,
|
||||
@@ -402,6 +541,7 @@ class DecodeTransferQueue:
|
||||
self.queue: List[DecodeRequest] = []
|
||||
self.gloo_group = gloo_group
|
||||
self.req_to_metadata_buffer_idx_allocator = req_to_metadata_buffer_idx_allocator
|
||||
self.tp_rank = tp_rank
|
||||
self.metadata_buffers = metadata_buffers
|
||||
self.scheduler = scheduler
|
||||
self.tree_cache = tree_cache
|
||||
@@ -412,10 +552,9 @@ class DecodeTransferQueue:
|
||||
def extend(self, decode_reqs: List[DecodeRequest]) -> None:
|
||||
self.queue.extend(decode_reqs)
|
||||
|
||||
def pop_transferred(self) -> List[DecodeRequest]:
|
||||
def pop_transferred(self) -> List[Req]:
|
||||
if not self.queue:
|
||||
return []
|
||||
|
||||
polls = poll_and_all_reduce(
|
||||
[decode_req.kv_receiver for decode_req in self.queue], self.gloo_group
|
||||
)
|
||||
@@ -424,7 +563,7 @@ class DecodeTransferQueue:
|
||||
indices_to_remove = set()
|
||||
for i, (decode_req, poll) in enumerate(zip(self.queue, polls)):
|
||||
if poll == KVPoll.Failed:
|
||||
error_message = f"Decode transfer failed for request rank={self.scheduler.tp_rank} {decode_req.req.rid=} {decode_req.req.bootstrap_room=}"
|
||||
error_message = f"Decode transfer failed for request rank={self.tp_rank} {decode_req.req.rid=} {decode_req.req.bootstrap_room=}"
|
||||
try:
|
||||
decode_req.kv_receiver.failure_exception()
|
||||
except Exception as e:
|
||||
@@ -543,7 +682,8 @@ class SchedulerDisaggregationDecodeMixin:
|
||||
batch, _ = self._prepare_idle_batch_and_run(None)
|
||||
|
||||
if batch is None and (
|
||||
len(self.disagg_decode_transfer_queue.queue)
|
||||
len(self.waiting_queue)
|
||||
+ len(self.disagg_decode_transfer_queue.queue)
|
||||
+ len(self.disagg_decode_prealloc_queue.queue)
|
||||
== 0
|
||||
):
|
||||
@@ -622,7 +762,8 @@ class SchedulerDisaggregationDecodeMixin:
|
||||
self.process_batch_result(tmp_batch, tmp_result)
|
||||
|
||||
if batch is None and (
|
||||
len(self.disagg_decode_transfer_queue.queue)
|
||||
len(self.waiting_queue)
|
||||
+ len(self.disagg_decode_transfer_queue.queue)
|
||||
+ len(self.disagg_decode_prealloc_queue.queue)
|
||||
== 0
|
||||
):
|
||||
@@ -716,6 +857,13 @@ class SchedulerDisaggregationDecodeMixin:
|
||||
return new_batch
|
||||
|
||||
def process_decode_queue(self: Scheduler):
|
||||
# try to resume retracted requests if there are enough space for another `num_reserved_decode_tokens` decode steps
|
||||
resumed_reqs = self.disagg_decode_prealloc_queue.resume_retracted_reqs()
|
||||
self.waiting_queue.extend(resumed_reqs)
|
||||
if len(self.disagg_decode_prealloc_queue.retracted_queue) > 0:
|
||||
# if there are still retracted requests, we do not allocate new requests
|
||||
return
|
||||
|
||||
req_conns = self.disagg_decode_prealloc_queue.pop_preallocated()
|
||||
self.disagg_decode_transfer_queue.extend(req_conns)
|
||||
alloc_reqs = (
|
||||
|
||||
@@ -25,6 +25,7 @@ from collections import deque
|
||||
from http import HTTPStatus
|
||||
from typing import TYPE_CHECKING, List, Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from sglang.srt.disaggregation.base import BaseKVManager, KVPoll
|
||||
@@ -575,6 +576,7 @@ class SchedulerDisaggregationPrefillMixin:
|
||||
self.req_to_token_pool.req_to_token[req.req_pool_idx, start_idx:end_idx]
|
||||
.cpu()
|
||||
.numpy()
|
||||
.astype(np.int64)
|
||||
)
|
||||
req.start_send_idx = end_idx
|
||||
if last_chunk:
|
||||
|
||||
@@ -1415,6 +1415,11 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
||||
req = self.reqs[idx]
|
||||
retracted_reqs.append(req)
|
||||
|
||||
if server_args.disaggregation_mode == "decode":
|
||||
req.offload_kv_cache(
|
||||
self.req_to_token_pool, self.token_to_kv_pool_allocator
|
||||
)
|
||||
|
||||
if isinstance(self.tree_cache, ChunkCache):
|
||||
# ChunkCache does not have eviction
|
||||
token_indices = self.req_to_token_pool.req_to_token[
|
||||
@@ -1446,6 +1451,12 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
||||
|
||||
req.reset_for_retract()
|
||||
|
||||
if len(retracted_reqs) == 0:
|
||||
# Corner case: only one request left
|
||||
raise ValueError(
|
||||
"Failed to retract any request. No space left for only one request."
|
||||
)
|
||||
|
||||
self.filter_batch(keep_indices=sorted_indices)
|
||||
|
||||
# Reqs in batch are filtered
|
||||
|
||||
@@ -628,6 +628,7 @@ class Scheduler(
|
||||
self.disagg_decode_transfer_queue = DecodeTransferQueue(
|
||||
gloo_group=self.attn_tp_cpu_group,
|
||||
req_to_metadata_buffer_idx_allocator=self.req_to_metadata_buffer_idx_allocator,
|
||||
tp_rank=self.tp_rank,
|
||||
metadata_buffers=self.disagg_metadata_buffers,
|
||||
scheduler=self,
|
||||
tree_cache=self.tree_cache,
|
||||
@@ -650,7 +651,11 @@ class Scheduler(
|
||||
gloo_group=self.attn_tp_cpu_group,
|
||||
tp_rank=self.tp_rank,
|
||||
tp_size=self.tp_size,
|
||||
dp_size=self.server_args.dp_size,
|
||||
gpu_id=self.gpu_id,
|
||||
bootstrap_port=self.server_args.disaggregation_bootstrap_port,
|
||||
max_total_num_tokens=self.max_total_num_tokens,
|
||||
prefill_pp_size=self.server_args.disaggregation_prefill_pp,
|
||||
transfer_backend=self.transfer_backend,
|
||||
)
|
||||
|
||||
@@ -1124,14 +1129,14 @@ class Scheduler(
|
||||
else:
|
||||
self.waiting_queue.append(req)
|
||||
|
||||
def _extend_requests_to_queue(self, reqs: List[Req]):
|
||||
def _extend_requests_to_queue(self, reqs: List[Req], is_retracted: bool = False):
|
||||
if self.disaggregation_mode == DisaggregationMode.PREFILL:
|
||||
self.disagg_prefill_bootstrap_queue.extend(
|
||||
reqs, self.model_config.num_key_value_heads
|
||||
)
|
||||
elif self.disaggregation_mode == DisaggregationMode.DECODE:
|
||||
# If this is a decode server, we put the request to the decode pending prealloc queue
|
||||
self.disagg_decode_prealloc_queue.extend(reqs)
|
||||
self.disagg_decode_prealloc_queue.extend(reqs, is_retracted)
|
||||
else:
|
||||
self.waiting_queue.extend(reqs)
|
||||
|
||||
@@ -1274,6 +1279,7 @@ class Scheduler(
|
||||
|
||||
if self.disaggregation_mode == DisaggregationMode.DECODE:
|
||||
msg += f"pre-allocated usage: {self.num_tokens_pre_allocated / self.max_total_num_tokens:.2f}, "
|
||||
msg += f"#retracted-req: {len(self.disagg_decode_prealloc_queue.retracted_queue)}, "
|
||||
|
||||
msg += (
|
||||
f"cuda graph: {can_run_cuda_graph}, "
|
||||
@@ -1575,7 +1581,7 @@ class Scheduler(
|
||||
f"#retracted_reqs: {len(retracted_reqs)}, "
|
||||
f"#new_token_ratio: {old_ratio:.4f} -> {self.new_token_ratio:.4f}"
|
||||
)
|
||||
self._extend_requests_to_queue(retracted_reqs)
|
||||
self._extend_requests_to_queue(retracted_reqs, is_retracted=True)
|
||||
else:
|
||||
self.new_token_ratio = max(
|
||||
self.new_token_ratio - self.new_token_ratio_decay,
|
||||
|
||||
@@ -234,6 +234,12 @@ class TokenToKVPoolAllocator:
|
||||
self.is_not_in_free_group = True
|
||||
self.free_group = []
|
||||
|
||||
def get_cpu_copy(self, indices):
|
||||
return self._kvcache.get_cpu_copy(indices)
|
||||
|
||||
def load_cpu_copy(self, kv_cache_cpu, indices):
|
||||
return self._kvcache.load_cpu_copy(kv_cache_cpu, indices)
|
||||
|
||||
|
||||
class MHATokenToKVPool(KVCache):
|
||||
|
||||
@@ -265,6 +271,8 @@ class MHATokenToKVPool(KVCache):
|
||||
self.head_dim = head_dim
|
||||
self._create_buffers()
|
||||
|
||||
# used for chunked cpu-offloading
|
||||
self.chunk_size = 8192
|
||||
self.layer_transfer_counter = None
|
||||
self.device_module = torch.get_device_module(self.device)
|
||||
self.alt_stream = self.device_module.Stream() if _is_cuda else None
|
||||
@@ -329,6 +337,39 @@ class MHATokenToKVPool(KVCache):
|
||||
]
|
||||
return kv_data_ptrs, kv_data_lens, kv_item_lens
|
||||
|
||||
def get_cpu_copy(self, indices):
|
||||
torch.cuda.synchronize()
|
||||
kv_cache_cpu = []
|
||||
for layer_id in range(self.layer_num):
|
||||
kv_cache_cpu.append([])
|
||||
for i in range(0, len(indices), self.chunk_size):
|
||||
chunk_indices = indices[i : i + self.chunk_size]
|
||||
k_cpu = self.k_buffer[layer_id][chunk_indices].to(
|
||||
"cpu", non_blocking=True
|
||||
)
|
||||
v_cpu = self.v_buffer[layer_id][chunk_indices].to(
|
||||
"cpu", non_blocking=True
|
||||
)
|
||||
kv_cache_cpu[-1].append([k_cpu, v_cpu])
|
||||
torch.cuda.synchronize()
|
||||
return kv_cache_cpu
|
||||
|
||||
def load_cpu_copy(self, kv_cache_cpu, indices):
|
||||
torch.cuda.synchronize()
|
||||
for layer_id in range(self.layer_num):
|
||||
for i in range(0, len(indices), self.chunk_size):
|
||||
chunk_indices = indices[i : i + self.chunk_size]
|
||||
k_cpu, v_cpu = (
|
||||
kv_cache_cpu[layer_id][i // self.chunk_size][0],
|
||||
kv_cache_cpu[layer_id][i // self.chunk_size][1],
|
||||
)
|
||||
assert k_cpu.shape[0] == v_cpu.shape[0] == len(chunk_indices)
|
||||
k_chunk = k_cpu.to(self.k_buffer[0].device, non_blocking=True)
|
||||
v_chunk = v_cpu.to(self.v_buffer[0].device, non_blocking=True)
|
||||
self.k_buffer[layer_id][chunk_indices] = k_chunk
|
||||
self.v_buffer[layer_id][chunk_indices] = v_chunk
|
||||
torch.cuda.synchronize()
|
||||
|
||||
# Todo: different memory layout
|
||||
def get_flat_data(self, indices):
|
||||
# prepare a large chunk of contiguous data for efficient transfer
|
||||
|
||||
@@ -469,5 +469,132 @@ class TestDisaggregationMooncakeSpec(CustomTestCase):
|
||||
self.assertGreater(metrics["accuracy"], 0.20)
|
||||
|
||||
|
||||
class TestDisaggregationSimulatedRetract(CustomTestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
os.environ["SGLANG_TEST_RETRACT"] = "true"
|
||||
cls.model = DEFAULT_MODEL_NAME_FOR_TEST
|
||||
parsed_url = urlparse(DEFAULT_URL_FOR_TEST)
|
||||
cls.base_host = parsed_url.hostname
|
||||
base_port = str(parsed_url.port)
|
||||
cls.lb_port = base_port
|
||||
cls.prefill_port = f"{int(base_port) + 100}"
|
||||
cls.decode_port = f"{int(base_port) + 200}"
|
||||
cls.prefill_url = f"http://{cls.base_host}:{cls.prefill_port}"
|
||||
cls.decode_url = f"http://{cls.base_host}:{cls.decode_port}"
|
||||
cls.lb_url = f"http://{cls.base_host}:{cls.lb_port}"
|
||||
print(f"{cls.base_host=} {cls.lb_port=} {cls.prefill_port=} {cls.decode_port=}")
|
||||
|
||||
# Non blocking start servers
|
||||
cls.start_prefill()
|
||||
cls.start_decode()
|
||||
|
||||
# Block until both
|
||||
cls.wait_server_ready(cls.prefill_url + "/health")
|
||||
cls.wait_server_ready(cls.decode_url + "/health")
|
||||
|
||||
lb_command = [
|
||||
"python3",
|
||||
"-m",
|
||||
"sglang.srt.disaggregation.mini_lb",
|
||||
"--prefill",
|
||||
cls.prefill_url,
|
||||
"--decode",
|
||||
cls.decode_url,
|
||||
"--host",
|
||||
cls.base_host,
|
||||
"--port",
|
||||
cls.lb_port,
|
||||
]
|
||||
|
||||
print("Starting load balancer:", " ".join(lb_command))
|
||||
cls.process_lb = subprocess.Popen(
|
||||
lb_command, stdout=subprocess.PIPE, stderr=subprocess.PIPE
|
||||
)
|
||||
cls.wait_server_ready(cls.lb_url + "/health")
|
||||
|
||||
@classmethod
|
||||
def start_prefill(cls):
|
||||
prefill_args = [
|
||||
"--trust-remote-code",
|
||||
"--disaggregation-mode",
|
||||
"prefill",
|
||||
"--tp",
|
||||
"1",
|
||||
"--disaggregation-ib-device",
|
||||
"mlx5_roce0",
|
||||
]
|
||||
cls.process_prefill = popen_launch_pd_server(
|
||||
cls.model,
|
||||
cls.prefill_url,
|
||||
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
other_args=prefill_args,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def start_decode(cls):
|
||||
decode_args = [
|
||||
"--trust-remote-code",
|
||||
"--disaggregation-mode",
|
||||
"decode",
|
||||
"--tp",
|
||||
"1",
|
||||
"--base-gpu-id",
|
||||
"1",
|
||||
"--disaggregation-ib-device",
|
||||
"mlx5_roce1",
|
||||
]
|
||||
cls.process_decode = popen_launch_pd_server(
|
||||
cls.model,
|
||||
cls.decode_url,
|
||||
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
other_args=decode_args,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def wait_server_ready(cls, url, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH):
|
||||
start_time = time.perf_counter()
|
||||
while True:
|
||||
try:
|
||||
response = requests.get(url)
|
||||
if response.status_code == 200:
|
||||
print(f"Server {url} is ready")
|
||||
return
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if time.perf_counter() - start_time > timeout:
|
||||
raise RuntimeError(f"Server {url} failed to start in {timeout}s")
|
||||
time.sleep(1)
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
os.environ.pop("SGLANG_TEST_RETRACT")
|
||||
for process in [cls.process_lb, cls.process_decode, cls.process_prefill]:
|
||||
if process:
|
||||
try:
|
||||
kill_process_tree(process.pid)
|
||||
except Exception as e:
|
||||
print(f"Error killing process {process.pid}: {e}")
|
||||
|
||||
# wait for 5 seconds
|
||||
time.sleep(5)
|
||||
|
||||
def test_gsm8k(self):
|
||||
args = SimpleNamespace(
|
||||
num_shots=5,
|
||||
data_path=None,
|
||||
num_questions=200,
|
||||
max_new_tokens=512,
|
||||
parallel=128,
|
||||
host=f"http://{self.base_host}",
|
||||
port=int(self.lb_port),
|
||||
)
|
||||
metrics = run_eval_few_shot_gsm8k(args)
|
||||
print(f"Evaluation metrics: {metrics}")
|
||||
|
||||
self.assertGreater(metrics["accuracy"], 0.62)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user