Add decode req pool (#6980)
This commit is contained in:
@@ -25,7 +25,7 @@ import os
|
|||||||
from collections import deque
|
from collections import deque
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from http import HTTPStatus
|
from http import HTTPStatus
|
||||||
from typing import TYPE_CHECKING, List, Optional, Tuple
|
from typing import TYPE_CHECKING, List, Optional, Tuple, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
@@ -49,6 +49,7 @@ from sglang.srt.managers.schedule_batch import FINISH_ABORT, ScheduleBatch
|
|||||||
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
|
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
|
||||||
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
|
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
|
||||||
from sglang.srt.model_executor.forward_batch_info import ForwardMode
|
from sglang.srt.model_executor.forward_batch_info import ForwardMode
|
||||||
|
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -57,6 +58,67 @@ if TYPE_CHECKING:
|
|||||||
from sglang.srt.managers.scheduler import Scheduler
|
from sglang.srt.managers.scheduler import Scheduler
|
||||||
|
|
||||||
|
|
||||||
|
class DecodeReqToTokenPool:
    """
    A request-to-token-indices pool for the decode side of disaggregated serving.

    The difference between DecodeReqToTokenPool and ReqToTokenPool is that
    DecodeReqToTokenPool reserves extra memory for pre-allocated requests.

    In ReqToTokenPool, if `--max-running-requests` is 8,
    #pre-allocated + #transfer + #running <= 8, but there is in fact more
    memory that can carry pre-allocated requests.

    In DecodeReqToTokenPool, if `--max-running-requests` is 8,
    #running <= 8 and #pre-allocated + #transfer <= pre_alloc_size, so we can
    use the free memory to pre-allocate requests to unblock prefill.
    """

    def __init__(
        self,
        size: int,
        max_context_len: int,
        device: str,
        enable_memory_saver: bool,
        pre_alloc_size: int,
    ):
        """
        Args:
            size: Maximum number of concurrently running requests.
            max_context_len: Number of token slots reserved per request row.
            device: Torch device string the pool tensor lives on.
            enable_memory_saver: Whether to register the pool tensor with the
                torch memory saver adapter.
            pre_alloc_size: Extra rows reserved for pre-allocated / in-transfer
                requests beyond the running limit.
        """
        memory_saver_adapter = TorchMemorySaverAdapter.create(
            enable=enable_memory_saver
        )

        self.size = size
        self.max_context_len = max_context_len
        self.device = device
        self.pre_alloc_size = pre_alloc_size
        with memory_saver_adapter.region():
            # One row per slot; the extra `pre_alloc_size` rows carry requests
            # that are pre-allocated (or still transferring KV) beyond `size`.
            self.req_to_token = torch.zeros(
                (size + pre_alloc_size, max_context_len),
                dtype=torch.int32,
                device=device,
            )

        self.free_slots = list(range(size + pre_alloc_size))

    def write(self, indices, values):
        """Write token indices `values` into the rows/cells selected by `indices`."""
        self.req_to_token[indices] = values

    def available_size(self) -> int:
        """Return the number of currently free request slots."""
        return len(self.free_slots)

    def alloc(self, need_size: int) -> Optional[List[int]]:
        """Allocate `need_size` slots and return their indices.

        Returns None (not an empty list) when not enough slots are free,
        mirroring ReqToTokenPool.alloc so callers can treat both pools
        interchangeably.
        """
        if need_size > len(self.free_slots):
            return None

        select_index = self.free_slots[:need_size]
        self.free_slots = self.free_slots[need_size:]
        return select_index

    def free(self, free_index: Union[int, List[int]]):
        """Return one slot (int) or several slots (list of int) to the free pool."""
        if isinstance(free_index, int):
            self.free_slots.append(free_index)
        else:
            self.free_slots.extend(free_index)

    def clear(self):
        """Reset the pool so every slot (including pre-alloc rows) is free."""
        self.free_slots = list(range(self.size + self.pre_alloc_size))
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class DecodeRequest:
|
class DecodeRequest:
|
||||||
req: Req
|
req: Req
|
||||||
|
|||||||
@@ -916,12 +916,26 @@ class ModelRunner:
|
|||||||
)
|
)
|
||||||
|
|
||||||
if self.req_to_token_pool is None:
|
if self.req_to_token_pool is None:
|
||||||
self.req_to_token_pool = ReqToTokenPool(
|
if self.server_args.disaggregation_mode == "decode":
|
||||||
size=max_num_reqs,
|
from sglang.srt.disaggregation.decode import DecodeReqToTokenPool
|
||||||
max_context_len=self.model_config.context_len + 4,
|
|
||||||
device=self.device,
|
# subscribe memory for pre-allocated requests
|
||||||
enable_memory_saver=self.server_args.enable_memory_saver,
|
# if max_num_reqs <= 32, we pre-allocate 2x requests
|
||||||
)
|
pre_alloc_size = max_num_reqs * 2 if max_num_reqs <= 32 else 0
|
||||||
|
self.req_to_token_pool = DecodeReqToTokenPool(
|
||||||
|
size=max_num_reqs,
|
||||||
|
max_context_len=self.model_config.context_len + 4,
|
||||||
|
device=self.device,
|
||||||
|
enable_memory_saver=self.server_args.enable_memory_saver,
|
||||||
|
pre_alloc_size=pre_alloc_size,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.req_to_token_pool = ReqToTokenPool(
|
||||||
|
size=max_num_reqs,
|
||||||
|
max_context_len=self.model_config.context_len + 4,
|
||||||
|
device=self.device,
|
||||||
|
enable_memory_saver=self.server_args.enable_memory_saver,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
# Draft worker shares req_to_token_pool with the target worker.
|
# Draft worker shares req_to_token_pool with the target worker.
|
||||||
assert self.is_draft_worker
|
assert self.is_draft_worker
|
||||||
|
|||||||
Reference in New Issue
Block a user