diff --git a/python/sglang/srt/disaggregation/decode.py b/python/sglang/srt/disaggregation/decode.py
index e2cc25eeb..336b0581d 100644
--- a/python/sglang/srt/disaggregation/decode.py
+++ b/python/sglang/srt/disaggregation/decode.py
@@ -31,7 +31,7 @@ import numpy as np
 import torch
 from torch.distributed import ProcessGroup
 
-from sglang.srt.disaggregation.base import BaseKVManager, BaseKVReceiver, KVArgs, KVPoll
+from sglang.srt.disaggregation.base import BaseKVManager, BaseKVReceiver, KVPoll
 from sglang.srt.disaggregation.utils import (
     FAKE_BOOTSTRAP_HOST,
     DisaggregationMode,
@@ -45,9 +45,17 @@ from sglang.srt.disaggregation.utils import (
     poll_and_all_reduce,
     prepare_abort,
 )
-from sglang.srt.managers.schedule_batch import FINISH_ABORT, ScheduleBatch
+from sglang.srt.managers.schedule_batch import (
+    FINISH_ABORT,
+    ScheduleBatch,
+    global_server_args_dict,
+)
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
-from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
+from sglang.srt.mem_cache.memory_pool import (
+    KVCache,
+    ReqToTokenPool,
+    TokenToKVPoolAllocator,
+)
 from sglang.srt.model_executor.forward_batch_info import ForwardMode
 from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
 
@@ -145,7 +153,11 @@ class DecodePreallocQueue:
         gloo_group: ProcessGroup,
         tp_rank: int,
         tp_size: int,
+        dp_size: int,
+        gpu_id: int,
         bootstrap_port: int,
+        max_total_num_tokens: int,
+        prefill_pp_size: int,
         transfer_backend: TransferBackend,
     ):
         self.req_to_token_pool = req_to_token_pool
@@ -161,25 +173,35 @@ class DecodePreallocQueue:
         self.gloo_group = gloo_group
         self.tp_rank = tp_rank
         self.tp_size = tp_size
+        self.dp_size = dp_size
+        self.gpu_id = gpu_id
         self.bootstrap_port = bootstrap_port
-
+        self.max_total_num_tokens = max_total_num_tokens
+        self.prefill_pp_size = prefill_pp_size
         self.num_reserved_decode_tokens = int(
             os.environ.get("SGLANG_NUM_RESERVED_DECODE_TOKENS", "512")
         )
-
+        self.transfer_backend = transfer_backend
         # Queue for requests pending pre-allocation
         self.queue: List[DecodeRequest] = []
-        self.transfer_backend = transfer_backend
+        self.retracted_queue: List[Req] = []
+        self.prefill_pp_size = prefill_pp_size
         self.kv_manager = self._init_kv_manager()
 
     def _init_kv_manager(self) -> BaseKVManager:
-        kv_args = KVArgs()
-        kv_args.engine_rank = self.tp_rank
+        kv_args_class = get_kv_class(self.transfer_backend, KVClassType.KVARGS)
+        kv_args = kv_args_class()
+
+        attn_tp_size = self.tp_size // self.dp_size
+        kv_args.engine_rank = self.tp_rank % (attn_tp_size)
+        kv_args.decode_tp_size = attn_tp_size
+        kv_args.prefill_pp_size = self.prefill_pp_size
         kv_data_ptrs, kv_data_lens, kv_item_lens = (
             self.token_to_kv_pool.get_contiguous_buf_infos()
         )
-
         if self.draft_token_to_kv_pool is not None:
+            # We should also transfer the draft model kv cache. The indices are
+            # always shared with the target model.
             draft_kv_data_ptrs, draft_kv_data_lens, draft_kv_item_lens = (
                 self.draft_token_to_kv_pool.get_contiguous_buf_infos()
             )
@@ -194,6 +216,7 @@ class DecodePreallocQueue:
         kv_args.aux_data_ptrs, kv_args.aux_data_lens, kv_args.aux_item_lens = (
             self.metadata_buffers.get_buf_infos()
         )
+        kv_args.ib_device = self.scheduler.server_args.disaggregation_ib_device
         kv_args.gpu_id = self.scheduler.gpu_id
 
         kv_manager_class = get_kv_class(self.transfer_backend, KVClassType.MANAGER)
@@ -205,27 +228,83 @@ class DecodePreallocQueue:
         )
         return kv_manager
 
-    def add(self, req: Req) -> None:
+    def add(self, req: Req, is_retracted: bool = False) -> None:
         """Add a request to the pending queue."""
-        if req.bootstrap_host == FAKE_BOOTSTRAP_HOST:
-            # Fake transfer for warmup reqs
-            kv_receiver_class = get_kv_class(TransferBackend.FAKE, KVClassType.RECEIVER)
-        else:
-            kv_receiver_class = get_kv_class(
-                self.transfer_backend, KVClassType.RECEIVER
-            )
-        kv_receiver = kv_receiver_class(
-            mgr=self.kv_manager,
-            bootstrap_addr=f"{req.bootstrap_host}:{req.bootstrap_port}",
-            bootstrap_room=req.bootstrap_room,
-            data_parallel_rank=req.data_parallel_rank,
-        )
-        self.queue.append(DecodeRequest(req=req, kv_receiver=kv_receiver))
+        if self._check_if_req_exceed_kv_capacity(req):
+            return
 
-    def extend(self, reqs: List[Req]) -> None:
+        if is_retracted:
+            self.retracted_queue.append(req)
+        else:
+            if req.bootstrap_host == FAKE_BOOTSTRAP_HOST:
+                kv_receiver_class = get_kv_class(
+                    TransferBackend.FAKE, KVClassType.RECEIVER
+                )
+            else:
+                kv_receiver_class = get_kv_class(
+                    self.transfer_backend, KVClassType.RECEIVER
+                )
+
+            kv_receiver = kv_receiver_class(
+                mgr=self.kv_manager,
+                bootstrap_addr=f"{req.bootstrap_host}:{req.bootstrap_port}",
+                bootstrap_room=req.bootstrap_room,
+            )
+
+            self.queue.append(
+                DecodeRequest(req=req, kv_receiver=kv_receiver, waiting_for_input=False)
+            )
+
+    def _check_if_req_exceed_kv_capacity(self, req: Req) -> bool:
+        if len(req.origin_input_ids) > self.max_total_num_tokens:
+            message = f"Request {req.rid} exceeds the maximum number of tokens: {len(req.origin_input_ids)} > {self.max_total_num_tokens}"
+            logger.error(message)
+            prepare_abort(req, message)
+            self.scheduler.stream_output([req], req.return_logprob)
+            return True
+        return False
+
+    def extend(self, reqs: List[Req], is_retracted: bool = False) -> None:
         """Add a request to the pending queue."""
         for req in reqs:
-            self.add(req)
+            self.add(req, is_retracted=is_retracted)
+
+    def resume_retracted_reqs(self) -> List[Req]:
+        # TODO refactor the scheduling part, reuse with the unified engine logic as much as possible
+
+        # allocate memory
+        resumed_reqs = []
+        indices_to_remove = set()
+        allocatable_tokens = self._allocatable_tokens(count_retracted=False)
+
+        for i, req in enumerate(self.retracted_queue):
+            if self.req_to_token_pool.available_size() <= 0:
+                break
+
+            required_tokens_for_request = (
+                len(req.origin_input_ids)
+                + len(req.output_ids)
+                + self.num_reserved_decode_tokens
+            )
+            if required_tokens_for_request > allocatable_tokens:
+                break
+
+            resumed_reqs.append(req)
+            indices_to_remove.add(i)
+            req.is_retracted = False
+            self._pre_alloc(req)
+            allocatable_tokens -= required_tokens_for_request
+
+            # load from cpu, release the cpu copy
+            req.load_kv_cache(self.req_to_token_pool, self.token_to_kv_pool_allocator)
+
+        self.retracted_queue = [
+            entry
+            for i, entry in enumerate(self.retracted_queue)
+            if i not in indices_to_remove
+        ]
+
+        return resumed_reqs
 
     def _update_handshake_waiters(self) -> None:
         if not self.queue:
             return
@@ -255,6 +334,8 @@ class DecodePreallocQueue:
                     error_message,
                     status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
                 )
+            else:
+                raise ValueError(f"Unexpected poll case: {poll}")
 
     def pop_preallocated(self) -> List[DecodeRequest]:
         """Pop the preallocated requests from the pending queue (FIFO)."""
@@ -262,8 +343,16 @@ class DecodePreallocQueue:
 
         preallocated_reqs = []
         indices_to_remove = set()
-        allocatable_tokens = self._allocatable_tokens()
 
+        # We need to make sure that the sum of inflight tokens and allocatable tokens is greater than the maximum input+output length of each inflight request.
+        # Otherwise, one request could run out of memory during decode while all other requests sit in the transfer queue and cannot be retracted.
+        retractable_tokens = sum(
+            len(r.origin_input_ids) + len(r.output_ids)
+            for r in self.scheduler.running_batch.reqs
+        )
+        allocatable_tokens = self._allocatable_tokens(
+            retractable_tokens=retractable_tokens, count_retracted=True
+        )
         # First, remove all failed requests from the queue
         for i, decode_req in enumerate(self.queue):
             if isinstance(decode_req.req.finished_reason, FINISH_ABORT):
@@ -272,6 +361,7 @@ class DecodePreallocQueue:
                 )
                 indices_to_remove.add(i)
 
+        # Then, preallocate the remaining requests if possible
         for i, decode_req in enumerate(self.queue):
             if i in indices_to_remove:
                 continue
@@ -285,10 +375,23 @@ class DecodePreallocQueue:
             if self.req_to_metadata_buffer_idx_allocator.available_size() <= 0:
                 break
 
+            # Memory estimation: don't add if the projected memory cannot be met
+            # TODO: add new_token ratio
+            origin_input_len = len(decode_req.req.origin_input_ids)
             required_tokens_for_request = (
-                len(decode_req.req.origin_input_ids) + self.num_reserved_decode_tokens
+                origin_input_len + self.num_reserved_decode_tokens
             )
 
+            if (
+                max(
+                    required_tokens_for_request,
+                    origin_input_len
+                    + decode_req.req.sampling_params.max_new_tokens
+                    - retractable_tokens,
+                )
+                > allocatable_tokens
+            ):
+                break
             if required_tokens_for_request > allocatable_tokens:
                 break
@@ -321,15 +424,35 @@ class DecodePreallocQueue:
 
         return preallocated_reqs
 
-    def _allocatable_tokens(self) -> int:
-        allocatable_tokens = (
-            self.token_to_kv_pool_allocator.available_size()
-            - self.num_reserved_decode_tokens
+    def _allocatable_tokens(
+        self, retractable_tokens: Optional[int] = None, count_retracted: bool = True
+    ) -> int:
+        need_space_for_single_req = (
+            max(
+                [
+                    x.sampling_params.max_new_tokens
+                    + len(x.origin_input_ids)
+                    - retractable_tokens
+                    for x in self.scheduler.running_batch.reqs
+                ]
+            )
+            if retractable_tokens is not None
+            and len(self.scheduler.running_batch.reqs) > 0
+            else 0
+        )
+
+        available_size = self.token_to_kv_pool_allocator.available_size()
+
+        allocatable_tokens = available_size - max(
+            # preserve some space for future decode
+            self.num_reserved_decode_tokens
             * (
                 len(self.scheduler.running_batch.reqs)
                 + len(self.transfer_queue.queue)
                 + len(self.scheduler.waiting_queue)
-            )
+            ),
+            # make sure each request can finish if it reaches max_tokens with all other requests retracted
+            need_space_for_single_req,
         )
 
         # Note: if the last fake extend just finishes, and we enter `pop_preallocated` immediately in the next iteration
@@ -342,15 +465,27 @@ class DecodePreallocQueue:
             self.scheduler.last_batch.reqs
         )
 
+        if count_retracted:
+            allocatable_tokens -= sum(
+                [
+                    len(req.origin_input_ids)
+                    + len(req.output_ids)
+                    + self.num_reserved_decode_tokens
+                    for req in self.retracted_queue
+                ]
+            )
+
         return allocatable_tokens
 
     def _pre_alloc(self, req: Req) -> torch.Tensor:
         """Pre-allocate the memory for req_to_token and token_kv_pool"""
         req_pool_indices = self.req_to_token_pool.alloc(1)
-        assert req_pool_indices is not None
+        assert (
+            req_pool_indices is not None
+        ), "req_pool_indices is full! There is a bug in memory estimation."
         req.req_pool_idx = req_pool_indices[0]
+
         if self.token_to_kv_pool_allocator.page_size == 1:
             kv_loc = self.token_to_kv_pool_allocator.alloc(
                 len(req.origin_input_ids) + max(len(req.output_ids) - 1, 0)
             )
@@ -375,7 +510,10 @@ class DecodePreallocQueue:
                 ),
                 extend_num_tokens=num_tokens,
             )
-        assert kv_loc is not None
+
+        assert (
+            kv_loc is not None
+        ), "KV cache is full! There is a bug in memory estimation."
 
         self.req_to_token_pool.write((req.req_pool_idx, slice(0, len(kv_loc))), kv_loc)
 
@@ -395,6 +533,7 @@ class DecodeTransferQueue:
         self,
         gloo_group: ProcessGroup,
         req_to_metadata_buffer_idx_allocator: ReqToMetadataIdxAllocator,
+        tp_rank: int,
         metadata_buffers: MetadataBuffers,
         scheduler: Scheduler,
         tree_cache: BasePrefixCache,
@@ -402,6 +541,7 @@
         self.queue: List[DecodeRequest] = []
         self.gloo_group = gloo_group
         self.req_to_metadata_buffer_idx_allocator = req_to_metadata_buffer_idx_allocator
+        self.tp_rank = tp_rank
         self.metadata_buffers = metadata_buffers
         self.scheduler = scheduler
         self.tree_cache = tree_cache
@@ -412,10 +552,9 @@ class DecodeTransferQueue:
     def extend(self, decode_reqs: List[DecodeRequest]) -> None:
         self.queue.extend(decode_reqs)
 
-    def pop_transferred(self) -> List[DecodeRequest]:
+    def pop_transferred(self) -> List[Req]:
         if not self.queue:
             return []
-
         polls = poll_and_all_reduce(
             [decode_req.kv_receiver for decode_req in self.queue], self.gloo_group
         )
@@ -424,7 +563,7 @@ class DecodeTransferQueue:
         indices_to_remove = set()
         for i, (decode_req, poll) in enumerate(zip(self.queue, polls)):
             if poll == KVPoll.Failed:
-                error_message = f"Decode transfer failed for request rank={self.scheduler.tp_rank} {decode_req.req.rid=} {decode_req.req.bootstrap_room=}"
+                error_message = f"Decode transfer failed for request rank={self.tp_rank} {decode_req.req.rid=} {decode_req.req.bootstrap_room=}"
                 try:
                     decode_req.kv_receiver.failure_exception()
                 except Exception as e:
@@ -543,7 +682,8 @@ class SchedulerDisaggregationDecodeMixin:
                 batch, _ = self._prepare_idle_batch_and_run(None)
 
             if batch is None and (
-                len(self.disagg_decode_transfer_queue.queue)
+                len(self.waiting_queue)
+                + len(self.disagg_decode_transfer_queue.queue)
                 + len(self.disagg_decode_prealloc_queue.queue)
                 == 0
             ):
@@ -622,7 +762,8 @@ class SchedulerDisaggregationDecodeMixin:
                 self.process_batch_result(tmp_batch, tmp_result)
 
             if batch is None and (
-                len(self.disagg_decode_transfer_queue.queue)
+                len(self.waiting_queue)
+                + len(self.disagg_decode_transfer_queue.queue)
                 + len(self.disagg_decode_prealloc_queue.queue)
                 == 0
            ):
@@ -716,6 +857,13 @@ class SchedulerDisaggregationDecodeMixin:
         return new_batch
 
     def process_decode_queue(self: Scheduler):
+        # try to resume retracted requests if there is enough space for another `num_reserved_decode_tokens` decode steps
+        resumed_reqs = self.disagg_decode_prealloc_queue.resume_retracted_reqs()
+        self.waiting_queue.extend(resumed_reqs)
+        if len(self.disagg_decode_prealloc_queue.retracted_queue) > 0:
+            # if there are still retracted requests, we do not allocate new requests
+            return
+
         req_conns = self.disagg_decode_prealloc_queue.pop_preallocated()
         self.disagg_decode_transfer_queue.extend(req_conns)
         alloc_reqs = (
diff --git a/python/sglang/srt/disaggregation/prefill.py b/python/sglang/srt/disaggregation/prefill.py
index 3382c9473..94be8b1f6 100644
--- a/python/sglang/srt/disaggregation/prefill.py
+++ b/python/sglang/srt/disaggregation/prefill.py
@@ -25,6 +25,7 @@ from collections import deque
 from http import HTTPStatus
 from typing import TYPE_CHECKING, List, Optional
 
+import numpy as np
 import torch
 
 from sglang.srt.disaggregation.base import BaseKVManager, KVPoll
@@ -575,6 +576,7 @@ class SchedulerDisaggregationPrefillMixin:
             self.req_to_token_pool.req_to_token[req.req_pool_idx, start_idx:end_idx]
             .cpu()
             .numpy()
+            .astype(np.int64)
         )
         req.start_send_idx = end_idx
         if last_chunk:
diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py
index 912018ca9..9ca8700f0 100644
--- a/python/sglang/srt/managers/schedule_batch.py
+++ b/python/sglang/srt/managers/schedule_batch.py
@@ -1415,6 +1415,11 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             req = self.reqs[idx]
             retracted_reqs.append(req)
 
+            if server_args.disaggregation_mode == "decode":
+                req.offload_kv_cache(
+                    self.req_to_token_pool, self.token_to_kv_pool_allocator
+                )
+
             if isinstance(self.tree_cache, ChunkCache):
                 # ChunkCache does not have eviction
                 token_indices = self.req_to_token_pool.req_to_token[
@@ -1446,6 +1451,12 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
 
             req.reset_for_retract()
 
+        if len(retracted_reqs) == 0:
+            # Corner case: only one request left
+            raise ValueError(
+                "Failed to retract any request. No space left for only one request."
+            )
+
         self.filter_batch(keep_indices=sorted_indices)
 
         # Reqs in batch are filtered
diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
index 0b3a76667..e327f7bae 100644
--- a/python/sglang/srt/managers/scheduler.py
+++ b/python/sglang/srt/managers/scheduler.py
@@ -628,6 +628,7 @@ class Scheduler(
             self.disagg_decode_transfer_queue = DecodeTransferQueue(
                 gloo_group=self.attn_tp_cpu_group,
                 req_to_metadata_buffer_idx_allocator=self.req_to_metadata_buffer_idx_allocator,
+                tp_rank=self.tp_rank,
                 metadata_buffers=self.disagg_metadata_buffers,
                 scheduler=self,
                 tree_cache=self.tree_cache,
@@ -650,7 +651,11 @@ class Scheduler(
                 gloo_group=self.attn_tp_cpu_group,
                 tp_rank=self.tp_rank,
                 tp_size=self.tp_size,
+                dp_size=self.server_args.dp_size,
+                gpu_id=self.gpu_id,
                 bootstrap_port=self.server_args.disaggregation_bootstrap_port,
+                max_total_num_tokens=self.max_total_num_tokens,
+                prefill_pp_size=self.server_args.disaggregation_prefill_pp,
                 transfer_backend=self.transfer_backend,
             )
 
@@ -1124,14 +1129,14 @@ class Scheduler(
         else:
             self.waiting_queue.append(req)
 
-    def _extend_requests_to_queue(self, reqs: List[Req]):
+    def _extend_requests_to_queue(self, reqs: List[Req], is_retracted: bool = False):
         if self.disaggregation_mode == DisaggregationMode.PREFILL:
             self.disagg_prefill_bootstrap_queue.extend(
                 reqs, self.model_config.num_key_value_heads
             )
         elif self.disaggregation_mode == DisaggregationMode.DECODE:
             # If this is a decode server, we put the request to the decode pending prealloc queue
-            self.disagg_decode_prealloc_queue.extend(reqs)
+            self.disagg_decode_prealloc_queue.extend(reqs, is_retracted)
         else:
             self.waiting_queue.extend(reqs)
 
@@ -1274,6 +1279,7 @@ class Scheduler(
 
         if self.disaggregation_mode == DisaggregationMode.DECODE:
             msg += f"pre-allocated usage: {self.num_tokens_pre_allocated / self.max_total_num_tokens:.2f}, "
+            msg += f"#retracted-req: {len(self.disagg_decode_prealloc_queue.retracted_queue)}, "
 
         msg += (
             f"cuda graph: {can_run_cuda_graph}, "
@@ -1575,7 +1581,7 @@ class Scheduler(
                    f"#retracted_reqs: {len(retracted_reqs)}, "
f"#new_token_ratio: {old_ratio:.4f} -> {self.new_token_ratio:.4f}" ) - self._extend_requests_to_queue(retracted_reqs) + self._extend_requests_to_queue(retracted_reqs, is_retracted=True) else: self.new_token_ratio = max( self.new_token_ratio - self.new_token_ratio_decay, diff --git a/python/sglang/srt/mem_cache/memory_pool.py b/python/sglang/srt/mem_cache/memory_pool.py index 2f3b7fdb6..8ae7ba6b1 100644 --- a/python/sglang/srt/mem_cache/memory_pool.py +++ b/python/sglang/srt/mem_cache/memory_pool.py @@ -234,6 +234,12 @@ class TokenToKVPoolAllocator: self.is_not_in_free_group = True self.free_group = [] + def get_cpu_copy(self, indices): + return self._kvcache.get_cpu_copy(indices) + + def load_cpu_copy(self, kv_cache_cpu, indices): + return self._kvcache.load_cpu_copy(kv_cache_cpu, indices) + class MHATokenToKVPool(KVCache): @@ -265,6 +271,8 @@ class MHATokenToKVPool(KVCache): self.head_dim = head_dim self._create_buffers() + # used for chunked cpu-offloading + self.chunk_size = 8192 self.layer_transfer_counter = None self.device_module = torch.get_device_module(self.device) self.alt_stream = self.device_module.Stream() if _is_cuda else None @@ -329,6 +337,39 @@ class MHATokenToKVPool(KVCache): ] return kv_data_ptrs, kv_data_lens, kv_item_lens + def get_cpu_copy(self, indices): + torch.cuda.synchronize() + kv_cache_cpu = [] + for layer_id in range(self.layer_num): + kv_cache_cpu.append([]) + for i in range(0, len(indices), self.chunk_size): + chunk_indices = indices[i : i + self.chunk_size] + k_cpu = self.k_buffer[layer_id][chunk_indices].to( + "cpu", non_blocking=True + ) + v_cpu = self.v_buffer[layer_id][chunk_indices].to( + "cpu", non_blocking=True + ) + kv_cache_cpu[-1].append([k_cpu, v_cpu]) + torch.cuda.synchronize() + return kv_cache_cpu + + def load_cpu_copy(self, kv_cache_cpu, indices): + torch.cuda.synchronize() + for layer_id in range(self.layer_num): + for i in range(0, len(indices), self.chunk_size): + chunk_indices = indices[i : i + self.chunk_size] + k_cpu, v_cpu = ( + kv_cache_cpu[layer_id][i // self.chunk_size][0], + kv_cache_cpu[layer_id][i // self.chunk_size][1], + ) + assert k_cpu.shape[0] == v_cpu.shape[0] == len(chunk_indices) + k_chunk = k_cpu.to(self.k_buffer[0].device, non_blocking=True) + v_chunk = v_cpu.to(self.v_buffer[0].device, non_blocking=True) + self.k_buffer[layer_id][chunk_indices] = k_chunk + self.v_buffer[layer_id][chunk_indices] = v_chunk + torch.cuda.synchronize() + # Todo: different memory layout def get_flat_data(self, indices): # prepare a large chunk of contiguous data for efficient transfer diff --git a/test/srt/test_disaggregation.py b/test/srt/test_disaggregation.py index a4a85eb36..b325314a2 100644 --- a/test/srt/test_disaggregation.py +++ b/test/srt/test_disaggregation.py @@ -469,5 +469,132 @@ class TestDisaggregationMooncakeSpec(CustomTestCase): self.assertGreater(metrics["accuracy"], 0.20) +class TestDisaggregationSimulatedRetract(CustomTestCase): + @classmethod + def setUpClass(cls): + os.environ["SGLANG_TEST_RETRACT"] = "true" + cls.model = DEFAULT_MODEL_NAME_FOR_TEST + parsed_url = urlparse(DEFAULT_URL_FOR_TEST) + cls.base_host = parsed_url.hostname + base_port = str(parsed_url.port) + cls.lb_port = base_port + cls.prefill_port = f"{int(base_port) + 100}" + cls.decode_port = f"{int(base_port) + 200}" + cls.prefill_url = f"http://{cls.base_host}:{cls.prefill_port}" + cls.decode_url = f"http://{cls.base_host}:{cls.decode_port}" + cls.lb_url = f"http://{cls.base_host}:{cls.lb_port}" + print(f"{cls.base_host=} {cls.lb_port=} 
+
+        # Non-blocking server start
+        cls.start_prefill()
+        cls.start_decode()
+
+        # Block until both servers are ready
+        cls.wait_server_ready(cls.prefill_url + "/health")
+        cls.wait_server_ready(cls.decode_url + "/health")
+
+        lb_command = [
+            "python3",
+            "-m",
+            "sglang.srt.disaggregation.mini_lb",
+            "--prefill",
+            cls.prefill_url,
+            "--decode",
+            cls.decode_url,
+            "--host",
+            cls.base_host,
+            "--port",
+            cls.lb_port,
+        ]
+
+        print("Starting load balancer:", " ".join(lb_command))
+        cls.process_lb = subprocess.Popen(
+            lb_command, stdout=subprocess.PIPE, stderr=subprocess.PIPE
+        )
+        cls.wait_server_ready(cls.lb_url + "/health")
+
+    @classmethod
+    def start_prefill(cls):
+        prefill_args = [
+            "--trust-remote-code",
+            "--disaggregation-mode",
+            "prefill",
+            "--tp",
+            "1",
+            "--disaggregation-ib-device",
+            "mlx5_roce0",
+        ]
+        cls.process_prefill = popen_launch_pd_server(
+            cls.model,
+            cls.prefill_url,
+            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+            other_args=prefill_args,
+        )
+
+    @classmethod
+    def start_decode(cls):
+        decode_args = [
+            "--trust-remote-code",
+            "--disaggregation-mode",
+            "decode",
+            "--tp",
+            "1",
+            "--base-gpu-id",
+            "1",
+            "--disaggregation-ib-device",
+            "mlx5_roce1",
+        ]
+        cls.process_decode = popen_launch_pd_server(
+            cls.model,
+            cls.decode_url,
+            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+            other_args=decode_args,
+        )
+
+    @classmethod
+    def wait_server_ready(cls, url, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH):
+        start_time = time.perf_counter()
+        while True:
+            try:
+                response = requests.get(url)
+                if response.status_code == 200:
+                    print(f"Server {url} is ready")
+                    return
+            except Exception:
+                pass
+
+            if time.perf_counter() - start_time > timeout:
+                raise RuntimeError(f"Server {url} failed to start in {timeout}s")
+            time.sleep(1)
+
+    @classmethod
+    def tearDownClass(cls):
+        os.environ.pop("SGLANG_TEST_RETRACT")
+        for process in [cls.process_lb, cls.process_decode, cls.process_prefill]:
+            if process:
+                try:
+                    kill_process_tree(process.pid)
+                except Exception as e:
+                    print(f"Error killing process {process.pid}: {e}")
+
+        # wait for 5 seconds
+        time.sleep(5)
+
+    def test_gsm8k(self):
+        args = SimpleNamespace(
+            num_shots=5,
+            data_path=None,
+            num_questions=200,
+            max_new_tokens=512,
+            parallel=128,
+            host=f"http://{self.base_host}",
+            port=int(self.lb_port),
+        )
+        metrics = run_eval_few_shot_gsm8k(args)
+        print(f"Evaluation metrics: {metrics}")
+
+        self.assertGreater(metrics["accuracy"], 0.62)
+
+
 if __name__ == "__main__":
     unittest.main()
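
Note (illustrative, not part of the patch): the retraction path added above saves a request's KV rows to host memory via MHATokenToKVPool.get_cpu_copy() and writes them back with load_cpu_copy() when the request is resumed. The sketch below reproduces that chunked device-to-host round trip on a single standalone tensor, assuming a CUDA device is available; the buffer shape, chunk size constant, and the offload_rows/restore_rows names are invented for the example and are not SGLang APIs.

import torch

CHUNK_SIZE = 8192  # mirrors the self.chunk_size value introduced in the patch

def offload_rows(gpu_buf, indices):
    # Copy the selected rows to CPU chunk by chunk; non_blocking=True issues the
    # copies and the final synchronize guarantees they have completed.
    chunks = []
    for i in range(0, len(indices), CHUNK_SIZE):
        idx = indices[i : i + CHUNK_SIZE]
        chunks.append(gpu_buf[idx].to("cpu", non_blocking=True))
    torch.cuda.synchronize()
    return chunks

def restore_rows(gpu_buf, chunks, indices):
    # Write the CPU chunks back into the rows named by `indices`; after a retract
    # the pool is re-allocated, so these rows may differ from the offloaded ones.
    for i in range(0, len(indices), CHUNK_SIZE):
        idx = indices[i : i + CHUNK_SIZE]
        chunk = chunks[i // CHUNK_SIZE]
        assert chunk.shape[0] == len(idx)
        gpu_buf[idx] = chunk.to(gpu_buf.device, non_blocking=True)
    torch.cuda.synchronize()

if __name__ == "__main__" and torch.cuda.is_available():
    kv = torch.randn(32768, 8, 128, device="cuda")              # hypothetical KV buffer
    rows = torch.randperm(kv.shape[0], device="cuda")[:10000]   # rows of one request
    saved = offload_rows(kv, rows)    # retract: move the rows to host memory
    restore_rows(kv, saved, rows)     # resume: copy them back into the pool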
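A second hedged sketch, this time of the budget that the reworked _allocatable_tokens() computes: free KV-pool tokens minus the larger of (a) the decode headroom reserved for every running, in-transfer, or waiting request and (b) the space the worst-case single request would still need to reach max_new_tokens if every other running request were retracted. The function name, the dict-based request records, and all numbers below are made up for illustration.

NUM_RESERVED_DECODE_TOKENS = 512  # default of SGLANG_NUM_RESERVED_DECODE_TOKENS

def allocatable_tokens(available, running, transfer_queue_len, waiting_len,
                       retractable_tokens=None):
    # (a) keep decode headroom for every request that may soon need tokens
    reserve = NUM_RESERVED_DECODE_TOKENS * (
        len(running) + transfer_queue_len + waiting_len
    )
    # (b) make sure one request can still finish with all others retracted
    need_single = 0
    if retractable_tokens is not None and running:
        need_single = max(
            r["max_new_tokens"] + r["input_len"] - retractable_tokens for r in running
        )
    return available - max(reserve, need_single)

running = [{"input_len": 1000, "output_len": 100, "max_new_tokens": 512}]
retractable = sum(r["input_len"] + r["output_len"] for r in running)  # 1100
print(allocatable_tokens(40000, running, transfer_queue_len=2, waiting_len=1,
                         retractable_tokens=retractable))
# reserve = 512 * 4 = 2048, need_single = 512 + 1000 - 1100 = 412,
# so the budget is 40000 - max(2048, 412) = 37952 tokens.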