diff --git a/python/sglang/srt/disaggregation/decode.py b/python/sglang/srt/disaggregation/decode.py index e206450b6..12b4408fd 100644 --- a/python/sglang/srt/disaggregation/decode.py +++ b/python/sglang/srt/disaggregation/decode.py @@ -25,7 +25,7 @@ import os from collections import deque from dataclasses import dataclass from http import HTTPStatus -from typing import TYPE_CHECKING, List, Optional, Tuple +from typing import TYPE_CHECKING, List, Optional, Tuple, Union import numpy as np import torch @@ -49,6 +49,7 @@ from sglang.srt.managers.schedule_batch import FINISH_ABORT, ScheduleBatch from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator from sglang.srt.model_executor.forward_batch_info import ForwardMode +from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter logger = logging.getLogger(__name__) @@ -57,6 +58,67 @@ if TYPE_CHECKING: from sglang.srt.managers.scheduler import Scheduler +class DecodeReqToTokenPool: + """ + The difference of DecodeReqToTokenPool and ReqToTokenPool is that + DecodeReqToTokenPool subscribes memory for pre-allocated requests. + + In ReqToTokenPool, if `--max-running-requests` is 8, + #pre-allocated + #transfer + #running <= 8, but there are in fact more memory can carry pre-allocated requests. + + In DecodeReqToTokenPool, if `--max-running-requests` is 8, + #running <= 8, #pre-allocated + #transfer <= pre_alloc_size, so we can use the free memory to pre-allocate requests to unblock prefill. + """ + + def __init__( + self, + size: int, + max_context_len: int, + device: str, + enable_memory_saver: bool, + pre_alloc_size: int, + ): + memory_saver_adapter = TorchMemorySaverAdapter.create( + enable=enable_memory_saver + ) + + self.size = size + self.max_context_len = max_context_len + self.device = device + self.pre_alloc_size = pre_alloc_size + with memory_saver_adapter.region(): + self.req_to_token = torch.zeros( + (size + pre_alloc_size, max_context_len), + dtype=torch.int32, + device=device, + ) + + self.free_slots = list(range(size + pre_alloc_size)) + + def write(self, indices, values): + self.req_to_token[indices] = values + + def available_size(self): + return len(self.free_slots) + + def alloc(self, need_size: int) -> List[int]: + if need_size > len(self.free_slots): + return None + + select_index = self.free_slots[:need_size] + self.free_slots = self.free_slots[need_size:] + return select_index + + def free(self, free_index: Union[int, List[int]]): + if isinstance(free_index, (int,)): + self.free_slots.append(free_index) + else: + self.free_slots.extend(free_index) + + def clear(self): + self.free_slots = list(range(self.size + self.pre_alloc_size)) + + @dataclass class DecodeRequest: req: Req diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 8700da9c8..995dedd02 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -916,12 +916,26 @@ class ModelRunner: ) if self.req_to_token_pool is None: - self.req_to_token_pool = ReqToTokenPool( - size=max_num_reqs, - max_context_len=self.model_config.context_len + 4, - device=self.device, - enable_memory_saver=self.server_args.enable_memory_saver, - ) + if self.server_args.disaggregation_mode == "decode": + from sglang.srt.disaggregation.decode import DecodeReqToTokenPool + + # subscribe memory for pre-allocated requests + # if max_num_reqs <= 32, we pre-allocate 2x requests + pre_alloc_size = max_num_reqs * 2 if max_num_reqs <= 32 else 0 + self.req_to_token_pool = DecodeReqToTokenPool( + size=max_num_reqs, + max_context_len=self.model_config.context_len + 4, + device=self.device, + enable_memory_saver=self.server_args.enable_memory_saver, + pre_alloc_size=pre_alloc_size, + ) + else: + self.req_to_token_pool = ReqToTokenPool( + size=max_num_reqs, + max_context_len=self.model_config.context_len + 4, + device=self.device, + enable_memory_saver=self.server_args.enable_memory_saver, + ) else: # Draft worker shares req_to_token_pool with the target worker. assert self.is_draft_worker