Json Decode && Multi-Turns (#4)
python/sglang/srt/backend_config.py (new file, +12 lines)
@@ -0,0 +1,12 @@
+"""
+Backend configurations, may vary with different serving platforms.
+"""
+from dataclasses import dataclass
+
+
+@dataclass
+class BackendConfig:
+    extend_dependency_time: float = 0.03
+
+
+GLOBAL_BACKEND_CONFIG = BackendConfig()
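Note: the new BackendConfig is a plain (non-frozen) dataclass exposed through a single module-level instance, so the value can be read or overridden in one place. A minimal usage sketch, hypothetical and not part of this commit:

    from sglang.srt.backend_config import GLOBAL_BACKEND_CONFIG

    # Default post-finish delay is 0.03 s; override before starting the router.
    GLOBAL_BACKEND_CONFIG.extend_dependency_time = 0.05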
@@ -1,6 +1,5 @@
import asyncio
import logging
-from typing import List, Tuple

import uvloop
import zmq
@@ -8,6 +7,7 @@ import zmq.asyncio
from sglang.srt.managers.router.model_rpc import ModelRpcClient
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import get_exception_traceback
+from sglang.srt.backend_config import GLOBAL_BACKEND_CONFIG

asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())

@@ -28,6 +28,9 @@ class RouterManager:
        self.model_client = model_client
        self.recv_reqs = []

+        # Init Some Configs
+        self.extend_dependency_time = GLOBAL_BACKEND_CONFIG.extend_dependency_time
+
    async def loop_for_forward(self):
        while True:
            next_step_input = list(self.recv_reqs)
@@ -37,7 +40,12 @@ class RouterManager:
            for obj in out_pyobjs:
                self.send_to_detokenizer.send_pyobj(obj)

-            # await for a while to accept input requests
+            # async sleep for recving the subsequent request, and avoiding cache miss
+            if len(out_pyobjs) != 0:
+                has_finished = any([obj.finished for obj in out_pyobjs])
+                if has_finished:
+                    await asyncio.sleep(self.extend_dependency_time)
+
            await asyncio.sleep(0.001)

    async def loop_for_recv_requests(self):
@@ -19,7 +19,6 @@ from sglang.srt.managers.router.model_runner import ModelRunner
from sglang.srt.managers.router.radix_cache import RadixCache
from sglang.srt.managers.router.scheduler import Scheduler
from sglang.srt.model_config import ModelConfig
from sglang.srt.sampling_params import SamplingParams
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import (
    get_exception_traceback,
@@ -158,6 +157,18 @@ class ModelRpcServer(rpyc.Service):
                if self.running_batch.is_empty():
                    self.running_batch = None
                    break
+                else:
+                    # check the available size
+                    available_size = (
+                        self.token_to_kv_pool.available_size()
+                        + self.tree_cache.evictable_size()
+                    )
+                    if available_size != self.max_total_num_token:
+                        logger.warning(
+                            "Warning: "
+                            f"available_size={available_size}, max_total_num_token={self.max_total_num_token}\n"
+                            "KV cache pool leak detected!"
+                        )

            if self.running_batch is not None and self.tp_rank == 0:
                if self.decode_forward_ct >= 20:
@@ -408,7 +419,9 @@ class ModelRpcServer(rpyc.Service):
        token_ids = tuple(req.input_ids + req.output_ids)
        seq_len = len(token_ids) - 1
        indices = self.req_to_token_pool.req_to_token[req_pool_idx, :seq_len]
-        prefix_len = self.tree_cache.insert(token_ids, indices.clone())
+        prefix_len = self.tree_cache.insert(
+            token_ids[:seq_len], indices.clone()
+        )

        self.token_to_kv_pool.free(indices[:prefix_len])
        self.req_to_token_pool.free(req_pool_idx)
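The insert fix makes the radix-cache key the same length as the KV index tensor: only seq_len = len(token_ids) - 1 slots are read from req_to_token, so the key is truncated to match. A small illustrative sketch with made-up values:

    token_ids = (101, 7, 8, 9)      # prompt + generated tokens of a finished request
    seq_len = len(token_ids) - 1    # 3 slots taken from req_to_token
    key = token_ids[:seq_len]       # (101, 7, 8) -- same length as indices[:seq_len]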
@@ -18,7 +18,7 @@ class Scheduler:
        self.tree_cache = tree_cache

    def new_token_estimation_ratio(self):
-        return 0.4 if self.schedule_heuristic != "fcfs" else 0.5
+        return 0.5 if self.schedule_heuristic != "fcfs" else 0.6

    def get_priority_queue(self, forward_queue):
        if self.schedule_heuristic == "lpm":
@@ -7,13 +7,13 @@ _SAMPLING_EPS = 1e-6
class SamplingParams:
    def __init__(
        self,
+        max_new_tokens: int = 16,
+        stop: Optional[Union[str, List[str]]] = None,
        temperature: float = 1.0,
        top_p: float = 1.0,
        top_k: int = -1,
        frequency_penalty: float = 0.0,
        presence_penalty: float = 0.0,
-        stop: Optional[Union[str, List[str]]] = None,
-        max_new_tokens: int = 16,
        ignore_eos: bool = False,
        skip_special_tokens: bool = True,
        dtype: Optional[str] = None,
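The reorder moves max_new_tokens and stop to the front of the signature, so they can be passed positionally before the sampling knobs. A hypothetical call site for illustration:

    # max_new_tokens and stop are now the first two parameters after self.
    params = SamplingParams(128, "\n", temperature=0.7, top_p=0.95)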
@@ -24,6 +24,8 @@ class ServerArgs:
    def __post_init__(self):
        if self.tokenizer_path is None:
            self.tokenizer_path = self.model_path
+        if self.tp_size > 1:
+            self.mem_fraction_static = 0.8

    @staticmethod
    def add_cli_args(parser: argparse.ArgumentParser):
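With tensor parallelism enabled, __post_init__ now also sets mem_fraction_static to 0.8. A sketch of the effect, assuming ServerArgs is a dataclass constructed directly (model path is made up):

    args = ServerArgs(model_path="meta-llama/Llama-2-7b-hf", tp_size=2)
    print(args.mem_fraction_static)  # 0.8, set in __post_init__ because tp_size > 1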