JSON Decode && Multi-Turn (#4)

This commit is contained in:
Liangsheng Yin
2024-01-15 16:49:29 +08:00
committed by GitHub
parent f652494df1
commit 08ab2a1655
27 changed files with 755 additions and 41 deletions

View File

@@ -37,6 +37,7 @@ def gen(
top_k: Optional[int] = None,
frequency_penalty: Optional[float] = None,
presence_penalty: Optional[float] = None,
ignore_eos: Optional[bool] = None,
dtype: Optional[type] = None,
choices: Optional[List[str]] = None,
regex: Optional[str] = None,
@@ -60,6 +61,7 @@ def gen(
top_k,
frequency_penalty,
presence_penalty,
ignore_eos,
dtype,
regex,
)
@@ -74,6 +76,7 @@ def gen_int(
top_k: Optional[int] = None,
frequency_penalty: Optional[float] = None,
presence_penalty: Optional[float] = None,
ignore_eos: Optional[bool] = None,
):
return SglGen(
name,
@@ -84,6 +87,7 @@ def gen_int(
top_k,
frequency_penalty,
presence_penalty,
ignore_eos,
int,
None,
)
@@ -98,6 +102,7 @@ def gen_string(
top_k: Optional[int] = None,
frequency_penalty: Optional[float] = None,
presence_penalty: Optional[float] = None,
ignore_eos: Optional[bool] = None,
):
return SglGen(
name,
@@ -108,6 +113,7 @@ def gen_string(
top_k,
frequency_penalty,
presence_penalty,
ignore_eos,
str,
None,
)

View File

@@ -4,7 +4,7 @@ import numpy as np
from sglang.backend.base_backend import BaseBackend
from sglang.lang.chat_template import get_chat_template
from sglang.lang.interpreter import StreamExecutor
from sglang.lang.ir import SamplingParams
from sglang.lang.ir import SglSamplingParams
try:
import anthropic
@@ -28,7 +28,7 @@ class Anthropic(BaseBackend):
def generate(
self,
s: StreamExecutor,
sampling_params: SamplingParams,
sampling_params: SglSamplingParams,
):
prompt = s.text_
ret = anthropic.Anthropic().completions.create(
@@ -43,7 +43,7 @@ class Anthropic(BaseBackend):
def generate_stream(
self,
s: StreamExecutor,
sampling_params: SamplingParams,
sampling_params: SglSamplingParams,
):
prompt = s.text_
generator = anthropic.Anthropic().completions.create(

View File

@@ -2,7 +2,7 @@ from typing import Callable, List, Optional, Union
from sglang.lang.chat_template import get_chat_template
from sglang.lang.interpreter import StreamExecutor
from sglang.lang.ir import SamplingParams
from sglang.lang.ir import SglSamplingParams
class BaseBackend:
@@ -48,14 +48,14 @@ class BaseBackend:
def generate(
self,
s: StreamExecutor,
sampling_params: SamplingParams,
sampling_params: SglSamplingParams,
):
raise NotImplementedError()
def generate_stream(
self,
s: StreamExecutor,
sampling_params: SamplingParams,
sampling_params: SglSamplingParams,
):
raise NotImplementedError()

View File

@@ -4,7 +4,7 @@ import numpy as np
from sglang.backend.base_backend import BaseBackend
from sglang.lang.chat_template import get_chat_template
from sglang.lang.interpreter import StreamExecutor
from sglang.lang.ir import SamplingParams
from sglang.lang.ir import SglSamplingParams
try:
import openai
@@ -73,7 +73,7 @@ class OpenAI(BaseBackend):
def generate(
self,
s: StreamExecutor,
sampling_params: SamplingParams,
sampling_params: SglSamplingParams,
):
if sampling_params.dtype is None:
if self.is_chat_model:
@@ -122,7 +122,7 @@ class OpenAI(BaseBackend):
def generate_stream(
self,
s: StreamExecutor,
sampling_params: SamplingParams,
sampling_params: SglSamplingParams,
):
if sampling_params.dtype is None:
if self.is_chat_model:

View File

@@ -7,7 +7,7 @@ from sglang.backend.base_backend import BaseBackend
from sglang.global_config import global_config
from sglang.lang.chat_template import get_chat_template_by_model_path
from sglang.lang.interpreter import StreamExecutor
from sglang.lang.ir import SamplingParams, SglArgument
from sglang.lang.ir import SglSamplingParams, SglArgument
from sglang.utils import encode_image_base64, find_printable_text, http_request
@@ -55,7 +55,7 @@ class RuntimeEndpoint(BaseBackend):
def generate(
self,
s: StreamExecutor,
sampling_params: SamplingParams,
sampling_params: SglSamplingParams,
):
if sampling_params.dtype is None:
data = {
@@ -87,7 +87,7 @@ class RuntimeEndpoint(BaseBackend):
def generate_stream(
self,
s: StreamExecutor,
sampling_params: SamplingParams,
sampling_params: SglSamplingParams,
):
if sampling_params.dtype is None:
data = {

View File

@@ -7,7 +7,7 @@ from typing import List, Optional, Union
from sglang.backend.base_backend import BaseBackend
from sglang.lang.chat_template import get_chat_template_by_model_path
from sglang.lang.interpreter import StreamExecutor
from sglang.lang.ir import SamplingParams
from sglang.lang.ir import SglSamplingParams
from sglang.utils import http_request
@@ -138,7 +138,7 @@ class TGI(BaseBackend):
self,
s: StreamExecutor,
choices: List[str],
sampling_params: SamplingParams,
sampling_params: SglSamplingParams,
):
decision = self.retry_for_expected(
s.text_,
@@ -152,7 +152,7 @@ class TGI(BaseBackend):
s: StreamExecutor,
max_tokens: int,
stop: Union[str, List[str]],
sampling_params: SamplingParams,
sampling_params: SglSamplingParams,
dtype: Optional[str] = None,
):
if dtype is None:

View File

@@ -6,7 +6,7 @@ from typing import List, Union
from sglang.global_config import global_config
from sglang.lang.interpreter import ProgramState, StreamExecutor, pin_program
from sglang.lang.ir import (
SamplingParams,
SglSamplingParams,
SglArgument,
SglConstantText,
SglExpr,
@@ -140,7 +140,7 @@ class CompiledFunction:
kwargs = {k: SglArgument(k, v) for k, v in kwargs.items()}
kwargs.update(self.function.bind_arguments)
default_sampling_para = SamplingParams(
default_sampling_para = SglSamplingParams(
max_new_tokens=max_new_tokens,
stop=stop,
temperature=temperature,
@@ -173,7 +173,7 @@ class CompiledFunction:
backend = backend or global_config.default_backend
default_sampling_para = SamplingParams(
default_sampling_para = SglSamplingParams(
max_new_tokens=max_new_tokens,
stop=stop,
temperature=temperature,

View File

@@ -292,7 +292,7 @@ class StreamExecutor:
assert isinstance(other, SglExpr), f"{other}"
if isinstance(other, (SglConstantText, SglArgument)):
if isinstance(other, SglConstantText):
self._execute_fill(other.value)
elif isinstance(other, SglGen):
self._execute_gen(other)
@@ -332,8 +332,6 @@ class StreamExecutor:
def _execute_image(self, expr: SglImage):
path = expr.path
if isinstance(path, SglArgument):
path = path.value
base64_data = encode_image_base64(path)
@@ -419,7 +417,7 @@ class StreamExecutor:
"role": expr.role,
"content": [{"type": "text", "text": new_text}],
}
for (image_path, image_base64_data) in self.cur_images:
for image_path, image_base64_data in self.cur_images:
last_msg["content"].append(
{
"type": "image_url",
@@ -480,6 +478,7 @@ class StreamExecutor:
"top_k",
"frequency_penalty",
"presence_penalty",
"ignore_eos",
"dtype",
"regex",
]:

View File

@@ -13,7 +13,7 @@ REGEX_STRING = r"\"[\w\d\s]*\"" # bugs with regex r"\".*\"" in interegular pkg
@dataclasses.dataclass
class SamplingParams:
class SglSamplingParams:
max_new_tokens: int = 16
stop: Union[str, List[str]] = ()
temperature: float = 1.0
@@ -21,13 +21,14 @@ class SamplingParams:
top_k: int = -1 # -1 means disable
frequency_penalty: float = 0.0
presence_penalty: float = 0.0
ignore_eos: bool = False
# for constrained generation, not included in to_xxx_kwargs
dtype: Optional[str] = None
regex: Optional[str] = None
def clone(self):
return SamplingParams(
return SglSamplingParams(
self.max_new_tokens,
self.stop,
self.temperature,
@@ -67,6 +68,7 @@ class SamplingParams:
"top_k": self.top_k,
"frequency_penalty": self.frequency_penalty,
"presence_penalty": self.presence_penalty,
"ignore_eos": self.ignore_eos,
"regex": self.regex,
}
@@ -98,13 +100,14 @@ class SglFunction:
top_k: int = -1,
frequency_penalty: float = 0.0,
presence_penalty: float = 0.0,
ignore_eos: bool = False,
stream: bool = False,
backend=None,
**kwargs,
):
from sglang.lang.interpreter import run_program
default_sampling_para = SamplingParams(
default_sampling_para = SglSamplingParams(
max_new_tokens=max_new_tokens,
stop=stop,
temperature=temperature,
@@ -112,9 +115,9 @@ class SglFunction:
top_k=top_k,
frequency_penalty=frequency_penalty,
presence_penalty=presence_penalty,
ignore_eos=ignore_eos,
)
backend = backend or global_config.default_backend
kwargs = {k: SglArgument(k, v) for k, v in kwargs.items()}
return run_program(self, backend, args, kwargs, default_sampling_para, stream)
def run_batch(
@@ -128,6 +131,7 @@ class SglFunction:
top_k: int = -1,
frequency_penalty: float = 0.0,
presence_penalty: float = 0.0,
ignore_eos: bool = False,
backend=None,
num_threads: Union[str, int] = "auto",
progress_bar: bool = False,
@@ -139,7 +143,7 @@ class SglFunction:
return []
assert isinstance(batch_kwargs[0], dict)
default_sampling_para = SamplingParams(
default_sampling_para = SglSamplingParams(
max_new_tokens=max_new_tokens,
stop=stop,
temperature=temperature,
@@ -147,11 +151,9 @@ class SglFunction:
top_k=top_k,
frequency_penalty=frequency_penalty,
presence_penalty=presence_penalty,
ignore_eos=ignore_eos,
)
backend = backend or global_config.default_backend
batch_kwargs = [
{k: SglArgument(k, v) for k, v in kwargs.items()} for kwargs in batch_kwargs
]
return run_program_batch(
self,
backend,
@@ -321,12 +323,13 @@ class SglGen(SglExpr):
top_k,
frequency_penalty,
presence_penalty,
ignore_eos,
dtype,
regex,
):
super().__init__()
self.name = name
self.sampling_params = SamplingParams(
self.sampling_params = SglSamplingParams(
max_new_tokens=max_new_tokens,
stop=stop,
temperature=temperature,
@@ -334,6 +337,7 @@ class SglGen(SglExpr):
top_k=top_k,
frequency_penalty=frequency_penalty,
presence_penalty=presence_penalty,
ignore_eos=ignore_eos,
dtype=dtype,
regex=regex,
)

View File

@@ -40,7 +40,8 @@ def extract_prefix_by_tracing(program, backend):
try:
with TracingScope(tracer):
tracer.ret_value = program.func(tracer, **arguments)
except StopTracing:
except (StopTracing, TypeError):
# Some exceptions may not be caught
pass
# Run and cache prefix

View File

@@ -0,0 +1,12 @@
"""
Backend configurations, may vary with different serving platforms.
"""
from dataclasses import dataclass
@dataclass
class BackendConfig:
extend_dependency_time: float = 0.03
GLOBAL_BACKEND_CONFIG = BackendConfig()

View File

@@ -1,6 +1,5 @@
import asyncio
import logging
from typing import List, Tuple
import uvloop
import zmq
@@ -8,6 +7,7 @@ import zmq.asyncio
from sglang.srt.managers.router.model_rpc import ModelRpcClient
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import get_exception_traceback
from sglang.srt.backend_config import GLOBAL_BACKEND_CONFIG
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
@@ -28,6 +28,9 @@ class RouterManager:
self.model_client = model_client
self.recv_reqs = []
# Init Some Configs
self.extend_dependency_time = GLOBAL_BACKEND_CONFIG.extend_dependency_time
async def loop_for_forward(self):
while True:
next_step_input = list(self.recv_reqs)
@@ -37,7 +40,12 @@ class RouterManager:
for obj in out_pyobjs:
self.send_to_detokenizer.send_pyobj(obj)
# await for a while to accept input requests
# async sleep for receiving the subsequent request, and avoiding cache miss
if len(out_pyobjs) != 0:
has_finished = any([obj.finished for obj in out_pyobjs])
if has_finished:
await asyncio.sleep(self.extend_dependency_time)
await asyncio.sleep(0.001)
async def loop_for_recv_requests(self):

View File

@@ -19,7 +19,6 @@ from sglang.srt.managers.router.model_runner import ModelRunner
from sglang.srt.managers.router.radix_cache import RadixCache
from sglang.srt.managers.router.scheduler import Scheduler
from sglang.srt.model_config import ModelConfig
from sglang.srt.sampling_params import SamplingParams
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import (
get_exception_traceback,
@@ -158,6 +157,18 @@ class ModelRpcServer(rpyc.Service):
if self.running_batch.is_empty():
self.running_batch = None
break
else:
# check the available size
available_size = (
self.token_to_kv_pool.available_size()
+ self.tree_cache.evictable_size()
)
if available_size != self.max_total_num_token:
logger.warning(
"Warning: "
f"available_size={available_size}, max_total_num_token={self.max_total_num_token}\n"
"KV cache pool leak detected!"
)
if self.running_batch is not None and self.tp_rank == 0:
if self.decode_forward_ct >= 20:
@@ -408,7 +419,9 @@ class ModelRpcServer(rpyc.Service):
token_ids = tuple(req.input_ids + req.output_ids)
seq_len = len(token_ids) - 1
indices = self.req_to_token_pool.req_to_token[req_pool_idx, :seq_len]
prefix_len = self.tree_cache.insert(token_ids, indices.clone())
prefix_len = self.tree_cache.insert(
token_ids[:seq_len], indices.clone()
)
self.token_to_kv_pool.free(indices[:prefix_len])
self.req_to_token_pool.free(req_pool_idx)

View File

@@ -18,7 +18,7 @@ class Scheduler:
self.tree_cache = tree_cache
def new_token_estimation_ratio(self):
return 0.4 if self.schedule_heuristic != "fcfs" else 0.5
return 0.5 if self.schedule_heuristic != "fcfs" else 0.6
def get_priority_queue(self, forward_queue):
if self.schedule_heuristic == "lpm":

View File

@@ -7,13 +7,13 @@ _SAMPLING_EPS = 1e-6
class SamplingParams:
def __init__(
self,
max_new_tokens: int = 16,
stop: Optional[Union[str, List[str]]] = None,
temperature: float = 1.0,
top_p: float = 1.0,
top_k: int = -1,
frequency_penalty: float = 0.0,
presence_penalty: float = 0.0,
stop: Optional[Union[str, List[str]]] = None,
max_new_tokens: int = 16,
ignore_eos: bool = False,
skip_special_tokens: bool = True,
dtype: Optional[str] = None,

View File

@@ -24,6 +24,8 @@ class ServerArgs:
def __post_init__(self):
if self.tokenizer_path is None:
self.tokenizer_path = self.model_path
if self.tp_size > 1:
self.mem_fraction_static = 0.8
@staticmethod
def add_cli_args(parser: argparse.ArgumentParser):

View File

@@ -38,6 +38,26 @@ def call_generate_vllm(prompt, temperature, max_tokens, stop, url, n=1):
return pred
def call_generate_outlines(
prompt, temperature, max_tokens, url, stop=[], regex=None, n=1
):
data = {
"prompt": prompt,
"temperature": temperature,
"max_tokens": max_tokens,
"stop": stop,
"regex": regex,
"n": n,
}
res = requests.post(url, json=data)
assert res.status_code == 200
if n == 1:
pred = res.json()["text"][0][len(prompt) :]
else:
pred = [x[len(prompt) :] for x in res.json()["text"]]
return pred
def call_generate_srt_raw(prompt, temperature, max_tokens, stop, url):
data = {
"text": prompt,

View File

@@ -67,7 +67,7 @@ def dump_state_text(filename, states, mode="w"):
if isinstance(s, str):
pass
elif isinstance(s, ProgramState):
s = s.text().strip()
s = s.text()
else:
s = str(s)