Add decode req pool (#6980)
This commit is contained in:
@@ -25,7 +25,7 @@ import os
|
||||
from collections import deque
|
||||
from dataclasses import dataclass
|
||||
from http import HTTPStatus
|
||||
from typing import TYPE_CHECKING, List, Optional, Tuple
|
||||
from typing import TYPE_CHECKING, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -49,6 +49,7 @@ from sglang.srt.managers.schedule_batch import FINISH_ABORT, ScheduleBatch
|
||||
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
|
||||
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardMode
|
||||
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -57,6 +58,67 @@ if TYPE_CHECKING:
|
||||
from sglang.srt.managers.scheduler import Scheduler
|
||||
|
||||
|
||||
class DecodeReqToTokenPool:
|
||||
"""
|
||||
The difference of DecodeReqToTokenPool and ReqToTokenPool is that
|
||||
DecodeReqToTokenPool subscribes memory for pre-allocated requests.
|
||||
|
||||
In ReqToTokenPool, if `--max-running-requests` is 8,
|
||||
#pre-allocated + #transfer + #running <= 8, but there are in fact more memory can carry pre-allocated requests.
|
||||
|
||||
In DecodeReqToTokenPool, if `--max-running-requests` is 8,
|
||||
#running <= 8, #pre-allocated + #transfer <= pre_alloc_size, so we can use the free memory to pre-allocate requests to unblock prefill.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
size: int,
|
||||
max_context_len: int,
|
||||
device: str,
|
||||
enable_memory_saver: bool,
|
||||
pre_alloc_size: int,
|
||||
):
|
||||
memory_saver_adapter = TorchMemorySaverAdapter.create(
|
||||
enable=enable_memory_saver
|
||||
)
|
||||
|
||||
self.size = size
|
||||
self.max_context_len = max_context_len
|
||||
self.device = device
|
||||
self.pre_alloc_size = pre_alloc_size
|
||||
with memory_saver_adapter.region():
|
||||
self.req_to_token = torch.zeros(
|
||||
(size + pre_alloc_size, max_context_len),
|
||||
dtype=torch.int32,
|
||||
device=device,
|
||||
)
|
||||
|
||||
self.free_slots = list(range(size + pre_alloc_size))
|
||||
|
||||
def write(self, indices, values):
|
||||
self.req_to_token[indices] = values
|
||||
|
||||
def available_size(self):
|
||||
return len(self.free_slots)
|
||||
|
||||
def alloc(self, need_size: int) -> List[int]:
|
||||
if need_size > len(self.free_slots):
|
||||
return None
|
||||
|
||||
select_index = self.free_slots[:need_size]
|
||||
self.free_slots = self.free_slots[need_size:]
|
||||
return select_index
|
||||
|
||||
def free(self, free_index: Union[int, List[int]]):
|
||||
if isinstance(free_index, (int,)):
|
||||
self.free_slots.append(free_index)
|
||||
else:
|
||||
self.free_slots.extend(free_index)
|
||||
|
||||
def clear(self):
|
||||
self.free_slots = list(range(self.size + self.pre_alloc_size))
|
||||
|
||||
|
||||
@dataclass
|
||||
class DecodeRequest:
|
||||
req: Req
|
||||
|
||||
@@ -916,12 +916,26 @@ class ModelRunner:
|
||||
)
|
||||
|
||||
if self.req_to_token_pool is None:
|
||||
self.req_to_token_pool = ReqToTokenPool(
|
||||
size=max_num_reqs,
|
||||
max_context_len=self.model_config.context_len + 4,
|
||||
device=self.device,
|
||||
enable_memory_saver=self.server_args.enable_memory_saver,
|
||||
)
|
||||
if self.server_args.disaggregation_mode == "decode":
|
||||
from sglang.srt.disaggregation.decode import DecodeReqToTokenPool
|
||||
|
||||
# subscribe memory for pre-allocated requests
|
||||
# if max_num_reqs <= 32, we pre-allocate 2x requests
|
||||
pre_alloc_size = max_num_reqs * 2 if max_num_reqs <= 32 else 0
|
||||
self.req_to_token_pool = DecodeReqToTokenPool(
|
||||
size=max_num_reqs,
|
||||
max_context_len=self.model_config.context_len + 4,
|
||||
device=self.device,
|
||||
enable_memory_saver=self.server_args.enable_memory_saver,
|
||||
pre_alloc_size=pre_alloc_size,
|
||||
)
|
||||
else:
|
||||
self.req_to_token_pool = ReqToTokenPool(
|
||||
size=max_num_reqs,
|
||||
max_context_len=self.model_config.context_len + 4,
|
||||
device=self.device,
|
||||
enable_memory_saver=self.server_args.enable_memory_saver,
|
||||
)
|
||||
else:
|
||||
# Draft worker shares req_to_token_pool with the target worker.
|
||||
assert self.is_draft_worker
|
||||
|
||||
Reference in New Issue
Block a user