Json Decode && Multi-Turns (#4)

This commit is contained in:
Liangsheng Yin
2024-01-15 16:49:29 +08:00
committed by GitHub
parent f652494df1
commit 08ab2a1655
27 changed files with 755 additions and 41 deletions

View File

@@ -0,0 +1,12 @@
"""
Backend configurations, may vary with different serving platforms.
"""
from dataclasses import dataclass
@dataclass
class BackendConfig:
    """Tunable backend settings whose values may differ per serving platform.

    Fields are ordinary dataclass attributes, so platform-specific code can
    either build its own instance or adjust the shared instance below.
    """

    # Seconds to pause after a request finishes, giving a dependent
    # follow-up request time to arrive (used by the router forward loop).
    extend_dependency_time: float = 0.03


# Process-wide default configuration, imported and read by other modules.
GLOBAL_BACKEND_CONFIG = BackendConfig()

View File

@@ -1,6 +1,5 @@
import asyncio
import logging
from typing import List, Tuple
import uvloop
import zmq
@@ -8,6 +7,7 @@ import zmq.asyncio
from sglang.srt.managers.router.model_rpc import ModelRpcClient
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import get_exception_traceback
from sglang.srt.backend_config import GLOBAL_BACKEND_CONFIG
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
@@ -28,6 +28,9 @@ class RouterManager:
self.model_client = model_client
self.recv_reqs = []
# Init Some Configs
self.extend_dependency_time = GLOBAL_BACKEND_CONFIG.extend_dependency_time
async def loop_for_forward(self):
while True:
next_step_input = list(self.recv_reqs)
@@ -37,7 +40,12 @@ class RouterManager:
for obj in out_pyobjs:
self.send_to_detokenizer.send_pyobj(obj)
# await for a while to accept input requests
# async sleep for receiving the subsequent request, and avoiding cache miss
if len(out_pyobjs) != 0:
has_finished = any([obj.finished for obj in out_pyobjs])
if has_finished:
await asyncio.sleep(self.extend_dependency_time)
await asyncio.sleep(0.001)
async def loop_for_recv_requests(self):

View File

@@ -19,7 +19,6 @@ from sglang.srt.managers.router.model_runner import ModelRunner
from sglang.srt.managers.router.radix_cache import RadixCache
from sglang.srt.managers.router.scheduler import Scheduler
from sglang.srt.model_config import ModelConfig
from sglang.srt.sampling_params import SamplingParams
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import (
get_exception_traceback,
@@ -158,6 +157,18 @@ class ModelRpcServer(rpyc.Service):
if self.running_batch.is_empty():
self.running_batch = None
break
else:
# check the available size
available_size = (
self.token_to_kv_pool.available_size()
+ self.tree_cache.evictable_size()
)
if available_size != self.max_total_num_token:
logger.warning(
"Warning: "
f"available_size={available_size}, max_total_num_token={self.max_total_num_token}\n"
"KV cache pool leak detected!"
)
if self.running_batch is not None and self.tp_rank == 0:
if self.decode_forward_ct >= 20:
@@ -408,7 +419,9 @@ class ModelRpcServer(rpyc.Service):
token_ids = tuple(req.input_ids + req.output_ids)
seq_len = len(token_ids) - 1
indices = self.req_to_token_pool.req_to_token[req_pool_idx, :seq_len]
prefix_len = self.tree_cache.insert(token_ids, indices.clone())
prefix_len = self.tree_cache.insert(
token_ids[:seq_len], indices.clone()
)
self.token_to_kv_pool.free(indices[:prefix_len])
self.req_to_token_pool.free(req_pool_idx)

View File

@@ -18,7 +18,7 @@ class Scheduler:
self.tree_cache = tree_cache
def new_token_estimation_ratio(self):
return 0.4 if self.schedule_heuristic != "fcfs" else 0.5
return 0.5 if self.schedule_heuristic != "fcfs" else 0.6
def get_priority_queue(self, forward_queue):
if self.schedule_heuristic == "lpm":

View File

@@ -7,13 +7,13 @@ _SAMPLING_EPS = 1e-6
class SamplingParams:
def __init__(
self,
max_new_tokens: int = 16,
stop: Optional[Union[str, List[str]]] = None,
temperature: float = 1.0,
top_p: float = 1.0,
top_k: int = -1,
frequency_penalty: float = 0.0,
presence_penalty: float = 0.0,
stop: Optional[Union[str, List[str]]] = None,
max_new_tokens: int = 16,
ignore_eos: bool = False,
skip_special_tokens: bool = True,
dtype: Optional[str] = None,

View File

@@ -24,6 +24,8 @@ class ServerArgs:
def __post_init__(self):
if self.tokenizer_path is None:
self.tokenizer_path = self.model_path
if self.tp_size > 1:
self.mem_fraction_static = 0.8
@staticmethod
def add_cli_args(parser: argparse.ArgumentParser):