Json Decode && Multi-Turns (#4)
This commit is contained in:
@@ -37,6 +37,7 @@ def gen(
|
||||
top_k: Optional[int] = None,
|
||||
frequency_penalty: Optional[float] = None,
|
||||
presence_penalty: Optional[float] = None,
|
||||
ignore_eos: Optional[bool] = None,
|
||||
dtype: Optional[type] = None,
|
||||
choices: Optional[List[str]] = None,
|
||||
regex: Optional[str] = None,
|
||||
@@ -60,6 +61,7 @@ def gen(
|
||||
top_k,
|
||||
frequency_penalty,
|
||||
presence_penalty,
|
||||
ignore_eos,
|
||||
dtype,
|
||||
regex,
|
||||
)
|
||||
@@ -74,6 +76,7 @@ def gen_int(
|
||||
top_k: Optional[int] = None,
|
||||
frequency_penalty: Optional[float] = None,
|
||||
presence_penalty: Optional[float] = None,
|
||||
ignore_eos: Optional[bool] = None,
|
||||
):
|
||||
return SglGen(
|
||||
name,
|
||||
@@ -84,6 +87,7 @@ def gen_int(
|
||||
top_k,
|
||||
frequency_penalty,
|
||||
presence_penalty,
|
||||
ignore_eos,
|
||||
int,
|
||||
None,
|
||||
)
|
||||
@@ -98,6 +102,7 @@ def gen_string(
|
||||
top_k: Optional[int] = None,
|
||||
frequency_penalty: Optional[float] = None,
|
||||
presence_penalty: Optional[float] = None,
|
||||
ignore_eos: Optional[bool] = None,
|
||||
):
|
||||
return SglGen(
|
||||
name,
|
||||
@@ -108,6 +113,7 @@ def gen_string(
|
||||
top_k,
|
||||
frequency_penalty,
|
||||
presence_penalty,
|
||||
ignore_eos,
|
||||
str,
|
||||
None,
|
||||
)
|
||||
|
||||
@@ -4,7 +4,7 @@ import numpy as np
|
||||
from sglang.backend.base_backend import BaseBackend
|
||||
from sglang.lang.chat_template import get_chat_template
|
||||
from sglang.lang.interpreter import StreamExecutor
|
||||
from sglang.lang.ir import SamplingParams
|
||||
from sglang.lang.ir import SglSamplingParams
|
||||
|
||||
try:
|
||||
import anthropic
|
||||
@@ -28,7 +28,7 @@ class Anthropic(BaseBackend):
|
||||
def generate(
|
||||
self,
|
||||
s: StreamExecutor,
|
||||
sampling_params: SamplingParams,
|
||||
sampling_params: SglSamplingParams,
|
||||
):
|
||||
prompt = s.text_
|
||||
ret = anthropic.Anthropic().completions.create(
|
||||
@@ -43,7 +43,7 @@ class Anthropic(BaseBackend):
|
||||
def generate_stream(
|
||||
self,
|
||||
s: StreamExecutor,
|
||||
sampling_params: SamplingParams,
|
||||
sampling_params: SglSamplingParams,
|
||||
):
|
||||
prompt = s.text_
|
||||
generator = anthropic.Anthropic().completions.create(
|
||||
|
||||
@@ -2,7 +2,7 @@ from typing import Callable, List, Optional, Union
|
||||
|
||||
from sglang.lang.chat_template import get_chat_template
|
||||
from sglang.lang.interpreter import StreamExecutor
|
||||
from sglang.lang.ir import SamplingParams
|
||||
from sglang.lang.ir import SglSamplingParams
|
||||
|
||||
|
||||
class BaseBackend:
|
||||
@@ -48,14 +48,14 @@ class BaseBackend:
|
||||
def generate(
|
||||
self,
|
||||
s: StreamExecutor,
|
||||
sampling_params: SamplingParams,
|
||||
sampling_params: SglSamplingParams,
|
||||
):
|
||||
raise NotImplementedError()
|
||||
|
||||
def generate_stream(
|
||||
self,
|
||||
s: StreamExecutor,
|
||||
sampling_params: SamplingParams,
|
||||
sampling_params: SglSamplingParams,
|
||||
):
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@ import numpy as np
|
||||
from sglang.backend.base_backend import BaseBackend
|
||||
from sglang.lang.chat_template import get_chat_template
|
||||
from sglang.lang.interpreter import StreamExecutor
|
||||
from sglang.lang.ir import SamplingParams
|
||||
from sglang.lang.ir import SglSamplingParams
|
||||
|
||||
try:
|
||||
import openai
|
||||
@@ -73,7 +73,7 @@ class OpenAI(BaseBackend):
|
||||
def generate(
|
||||
self,
|
||||
s: StreamExecutor,
|
||||
sampling_params: SamplingParams,
|
||||
sampling_params: SglSamplingParams,
|
||||
):
|
||||
if sampling_params.dtype is None:
|
||||
if self.is_chat_model:
|
||||
@@ -122,7 +122,7 @@ class OpenAI(BaseBackend):
|
||||
def generate_stream(
|
||||
self,
|
||||
s: StreamExecutor,
|
||||
sampling_params: SamplingParams,
|
||||
sampling_params: SglSamplingParams,
|
||||
):
|
||||
if sampling_params.dtype is None:
|
||||
if self.is_chat_model:
|
||||
|
||||
@@ -7,7 +7,7 @@ from sglang.backend.base_backend import BaseBackend
|
||||
from sglang.global_config import global_config
|
||||
from sglang.lang.chat_template import get_chat_template_by_model_path
|
||||
from sglang.lang.interpreter import StreamExecutor
|
||||
from sglang.lang.ir import SamplingParams, SglArgument
|
||||
from sglang.lang.ir import SglSamplingParams, SglArgument
|
||||
from sglang.utils import encode_image_base64, find_printable_text, http_request
|
||||
|
||||
|
||||
@@ -55,7 +55,7 @@ class RuntimeEndpoint(BaseBackend):
|
||||
def generate(
|
||||
self,
|
||||
s: StreamExecutor,
|
||||
sampling_params: SamplingParams,
|
||||
sampling_params: SglSamplingParams,
|
||||
):
|
||||
if sampling_params.dtype is None:
|
||||
data = {
|
||||
@@ -87,7 +87,7 @@ class RuntimeEndpoint(BaseBackend):
|
||||
def generate_stream(
|
||||
self,
|
||||
s: StreamExecutor,
|
||||
sampling_params: SamplingParams,
|
||||
sampling_params: SglSamplingParams,
|
||||
):
|
||||
if sampling_params.dtype is None:
|
||||
data = {
|
||||
|
||||
@@ -7,7 +7,7 @@ from typing import List, Optional, Union
|
||||
from sglang.backend.base_backend import BaseBackend
|
||||
from sglang.lang.chat_template import get_chat_template_by_model_path
|
||||
from sglang.lang.interpreter import StreamExecutor
|
||||
from sglang.lang.ir import SamplingParams
|
||||
from sglang.lang.ir import SglSamplingParams
|
||||
from sglang.utils import http_request
|
||||
|
||||
|
||||
@@ -138,7 +138,7 @@ class TGI(BaseBackend):
|
||||
self,
|
||||
s: StreamExecutor,
|
||||
choices: List[str],
|
||||
sampling_params: SamplingParams,
|
||||
sampling_params: SglSamplingParams,
|
||||
):
|
||||
decision = self.retry_for_expected(
|
||||
s.text_,
|
||||
@@ -152,7 +152,7 @@ class TGI(BaseBackend):
|
||||
s: StreamExecutor,
|
||||
max_tokens: int,
|
||||
stop: Union[str, List[str]],
|
||||
sampling_params: SamplingParams,
|
||||
sampling_params: SglSamplingParams,
|
||||
dtype: Optional[str] = None,
|
||||
):
|
||||
if dtype is None:
|
||||
|
||||
@@ -6,7 +6,7 @@ from typing import List, Union
|
||||
from sglang.global_config import global_config
|
||||
from sglang.lang.interpreter import ProgramState, StreamExecutor, pin_program
|
||||
from sglang.lang.ir import (
|
||||
SamplingParams,
|
||||
SglSamplingParams,
|
||||
SglArgument,
|
||||
SglConstantText,
|
||||
SglExpr,
|
||||
@@ -140,7 +140,7 @@ class CompiledFunction:
|
||||
kwargs = {k: SglArgument(k, v) for k, v in kwargs.items()}
|
||||
kwargs.update(self.function.bind_arguments)
|
||||
|
||||
default_sampling_para = SamplingParams(
|
||||
default_sampling_para = SglSamplingParams(
|
||||
max_new_tokens=max_new_tokens,
|
||||
stop=stop,
|
||||
temperature=temperature,
|
||||
@@ -173,7 +173,7 @@ class CompiledFunction:
|
||||
|
||||
backend = backend or global_config.default_backend
|
||||
|
||||
default_sampling_para = SamplingParams(
|
||||
default_sampling_para = SglSamplingParams(
|
||||
max_new_tokens=max_new_tokens,
|
||||
stop=stop,
|
||||
temperature=temperature,
|
||||
|
||||
@@ -292,7 +292,7 @@ class StreamExecutor:
|
||||
|
||||
assert isinstance(other, SglExpr), f"{other}"
|
||||
|
||||
if isinstance(other, (SglConstantText, SglArgument)):
|
||||
if isinstance(other, SglConstantText):
|
||||
self._execute_fill(other.value)
|
||||
elif isinstance(other, SglGen):
|
||||
self._execute_gen(other)
|
||||
@@ -332,8 +332,6 @@ class StreamExecutor:
|
||||
|
||||
def _execute_image(self, expr: SglImage):
|
||||
path = expr.path
|
||||
if isinstance(path, SglArgument):
|
||||
path = path.value
|
||||
|
||||
base64_data = encode_image_base64(path)
|
||||
|
||||
@@ -419,7 +417,7 @@ class StreamExecutor:
|
||||
"role": expr.role,
|
||||
"content": [{"type": "text", "text": new_text}],
|
||||
}
|
||||
for (image_path, image_base64_data) in self.cur_images:
|
||||
for image_path, image_base64_data in self.cur_images:
|
||||
last_msg["content"].append(
|
||||
{
|
||||
"type": "image_url",
|
||||
@@ -480,6 +478,7 @@ class StreamExecutor:
|
||||
"top_k",
|
||||
"frequency_penalty",
|
||||
"presence_penalty",
|
||||
"ignore_eos",
|
||||
"dtype",
|
||||
"regex",
|
||||
]:
|
||||
|
||||
@@ -13,7 +13,7 @@ REGEX_STRING = r"\"[\w\d\s]*\"" # bugs with regex r"\".*\"" in interegular pkg
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class SamplingParams:
|
||||
class SglSamplingParams:
|
||||
max_new_tokens: int = 16
|
||||
stop: Union[str, List[str]] = ()
|
||||
temperature: float = 1.0
|
||||
@@ -21,13 +21,14 @@ class SamplingParams:
|
||||
top_k: int = -1 # -1 means disable
|
||||
frequency_penalty: float = 0.0
|
||||
presence_penalty: float = 0.0
|
||||
ignore_eos: bool = False
|
||||
|
||||
# for constrained generation, not included in to_xxx_kwargs
|
||||
dtype: Optional[str] = None
|
||||
regex: Optional[str] = None
|
||||
|
||||
def clone(self):
|
||||
return SamplingParams(
|
||||
return SglSamplingParams(
|
||||
self.max_new_tokens,
|
||||
self.stop,
|
||||
self.temperature,
|
||||
@@ -67,6 +68,7 @@ class SamplingParams:
|
||||
"top_k": self.top_k,
|
||||
"frequency_penalty": self.frequency_penalty,
|
||||
"presence_penalty": self.presence_penalty,
|
||||
"ignore_eos": self.ignore_eos,
|
||||
"regex": self.regex,
|
||||
}
|
||||
|
||||
@@ -98,13 +100,14 @@ class SglFunction:
|
||||
top_k: int = -1,
|
||||
frequency_penalty: float = 0.0,
|
||||
presence_penalty: float = 0.0,
|
||||
ignore_eos: bool = False,
|
||||
stream: bool = False,
|
||||
backend=None,
|
||||
**kwargs,
|
||||
):
|
||||
from sglang.lang.interpreter import run_program
|
||||
|
||||
default_sampling_para = SamplingParams(
|
||||
default_sampling_para = SglSamplingParams(
|
||||
max_new_tokens=max_new_tokens,
|
||||
stop=stop,
|
||||
temperature=temperature,
|
||||
@@ -112,9 +115,9 @@ class SglFunction:
|
||||
top_k=top_k,
|
||||
frequency_penalty=frequency_penalty,
|
||||
presence_penalty=presence_penalty,
|
||||
ignore_eos=ignore_eos,
|
||||
)
|
||||
backend = backend or global_config.default_backend
|
||||
kwargs = {k: SglArgument(k, v) for k, v in kwargs.items()}
|
||||
return run_program(self, backend, args, kwargs, default_sampling_para, stream)
|
||||
|
||||
def run_batch(
|
||||
@@ -128,6 +131,7 @@ class SglFunction:
|
||||
top_k: int = -1,
|
||||
frequency_penalty: float = 0.0,
|
||||
presence_penalty: float = 0.0,
|
||||
ignore_eos: bool = False,
|
||||
backend=None,
|
||||
num_threads: Union[str, int] = "auto",
|
||||
progress_bar: bool = False,
|
||||
@@ -139,7 +143,7 @@ class SglFunction:
|
||||
return []
|
||||
assert isinstance(batch_kwargs[0], dict)
|
||||
|
||||
default_sampling_para = SamplingParams(
|
||||
default_sampling_para = SglSamplingParams(
|
||||
max_new_tokens=max_new_tokens,
|
||||
stop=stop,
|
||||
temperature=temperature,
|
||||
@@ -147,11 +151,9 @@ class SglFunction:
|
||||
top_k=top_k,
|
||||
frequency_penalty=frequency_penalty,
|
||||
presence_penalty=presence_penalty,
|
||||
ignore_eos=ignore_eos,
|
||||
)
|
||||
backend = backend or global_config.default_backend
|
||||
batch_kwargs = [
|
||||
{k: SglArgument(k, v) for k, v in kwargs.items()} for kwargs in batch_kwargs
|
||||
]
|
||||
return run_program_batch(
|
||||
self,
|
||||
backend,
|
||||
@@ -321,12 +323,13 @@ class SglGen(SglExpr):
|
||||
top_k,
|
||||
frequency_penalty,
|
||||
presence_penalty,
|
||||
ignore_eos,
|
||||
dtype,
|
||||
regex,
|
||||
):
|
||||
super().__init__()
|
||||
self.name = name
|
||||
self.sampling_params = SamplingParams(
|
||||
self.sampling_params = SglSamplingParams(
|
||||
max_new_tokens=max_new_tokens,
|
||||
stop=stop,
|
||||
temperature=temperature,
|
||||
@@ -334,6 +337,7 @@ class SglGen(SglExpr):
|
||||
top_k=top_k,
|
||||
frequency_penalty=frequency_penalty,
|
||||
presence_penalty=presence_penalty,
|
||||
ignore_eos=ignore_eos,
|
||||
dtype=dtype,
|
||||
regex=regex,
|
||||
)
|
||||
|
||||
@@ -40,7 +40,8 @@ def extract_prefix_by_tracing(program, backend):
|
||||
try:
|
||||
with TracingScope(tracer):
|
||||
tracer.ret_value = program.func(tracer, **arguments)
|
||||
except StopTracing:
|
||||
except (StopTracing, TypeError):
|
||||
# Some exceptions may not be caught
|
||||
pass
|
||||
|
||||
# Run and cache prefix
|
||||
|
||||
12
python/sglang/srt/backend_config.py
Normal file
12
python/sglang/srt/backend_config.py
Normal file
@@ -0,0 +1,12 @@
|
||||
"""
|
||||
Backend configurations, may vary with different serving platforms.
|
||||
"""
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass
class BackendConfig:
    """Backend settings that may differ between serving platforms."""

    # Delay handed to ``asyncio.sleep`` by the router after a forward step in
    # which at least one request finished, so a follow-up request can arrive
    # before the next step is scheduled (presumably seconds, as with any
    # ``asyncio.sleep`` argument).
    extend_dependency_time: float = 0.03


# Process-wide default configuration instance shared by importers.
GLOBAL_BACKEND_CONFIG = BackendConfig()
|
||||
@@ -1,6 +1,5 @@
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import List, Tuple
|
||||
|
||||
import uvloop
|
||||
import zmq
|
||||
@@ -8,6 +7,7 @@ import zmq.asyncio
|
||||
from sglang.srt.managers.router.model_rpc import ModelRpcClient
|
||||
from sglang.srt.server_args import PortArgs, ServerArgs
|
||||
from sglang.srt.utils import get_exception_traceback
|
||||
from sglang.srt.backend_config import GLOBAL_BACKEND_CONFIG
|
||||
|
||||
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
|
||||
|
||||
@@ -28,6 +28,9 @@ class RouterManager:
|
||||
self.model_client = model_client
|
||||
self.recv_reqs = []
|
||||
|
||||
# Init Some Configs
|
||||
self.extend_dependency_time = GLOBAL_BACKEND_CONFIG.extend_dependency_time
|
||||
|
||||
async def loop_for_forward(self):
|
||||
while True:
|
||||
next_step_input = list(self.recv_reqs)
|
||||
@@ -37,7 +40,12 @@ class RouterManager:
|
||||
for obj in out_pyobjs:
|
||||
self.send_to_detokenizer.send_pyobj(obj)
|
||||
|
||||
# await for a while to accept input requests
|
||||
# async sleep to receive subsequent requests and avoid cache misses
|
||||
if len(out_pyobjs) != 0:
|
||||
has_finished = any([obj.finished for obj in out_pyobjs])
|
||||
if has_finished:
|
||||
await asyncio.sleep(self.extend_dependency_time)
|
||||
|
||||
await asyncio.sleep(0.001)
|
||||
|
||||
async def loop_for_recv_requests(self):
|
||||
|
||||
@@ -19,7 +19,6 @@ from sglang.srt.managers.router.model_runner import ModelRunner
|
||||
from sglang.srt.managers.router.radix_cache import RadixCache
|
||||
from sglang.srt.managers.router.scheduler import Scheduler
|
||||
from sglang.srt.model_config import ModelConfig
|
||||
from sglang.srt.sampling_params import SamplingParams
|
||||
from sglang.srt.server_args import PortArgs, ServerArgs
|
||||
from sglang.srt.utils import (
|
||||
get_exception_traceback,
|
||||
@@ -158,6 +157,18 @@ class ModelRpcServer(rpyc.Service):
|
||||
if self.running_batch.is_empty():
|
||||
self.running_batch = None
|
||||
break
|
||||
else:
|
||||
# check the available size
|
||||
available_size = (
|
||||
self.token_to_kv_pool.available_size()
|
||||
+ self.tree_cache.evictable_size()
|
||||
)
|
||||
if available_size != self.max_total_num_token:
|
||||
logger.warning(
|
||||
"Warning: "
|
||||
f"available_size={available_size}, max_total_num_token={self.max_total_num_token}\n"
|
||||
"KV cache pool leak detected!"
|
||||
)
|
||||
|
||||
if self.running_batch is not None and self.tp_rank == 0:
|
||||
if self.decode_forward_ct >= 20:
|
||||
@@ -408,7 +419,9 @@ class ModelRpcServer(rpyc.Service):
|
||||
token_ids = tuple(req.input_ids + req.output_ids)
|
||||
seq_len = len(token_ids) - 1
|
||||
indices = self.req_to_token_pool.req_to_token[req_pool_idx, :seq_len]
|
||||
prefix_len = self.tree_cache.insert(token_ids, indices.clone())
|
||||
prefix_len = self.tree_cache.insert(
|
||||
token_ids[:seq_len], indices.clone()
|
||||
)
|
||||
|
||||
self.token_to_kv_pool.free(indices[:prefix_len])
|
||||
self.req_to_token_pool.free(req_pool_idx)
|
||||
|
||||
@@ -18,7 +18,7 @@ class Scheduler:
|
||||
self.tree_cache = tree_cache
|
||||
|
||||
def new_token_estimation_ratio(self):
|
||||
return 0.4 if self.schedule_heuristic != "fcfs" else 0.5
|
||||
return 0.5 if self.schedule_heuristic != "fcfs" else 0.6
|
||||
|
||||
def get_priority_queue(self, forward_queue):
|
||||
if self.schedule_heuristic == "lpm":
|
||||
|
||||
@@ -7,13 +7,13 @@ _SAMPLING_EPS = 1e-6
|
||||
class SamplingParams:
|
||||
def __init__(
|
||||
self,
|
||||
max_new_tokens: int = 16,
|
||||
stop: Optional[Union[str, List[str]]] = None,
|
||||
temperature: float = 1.0,
|
||||
top_p: float = 1.0,
|
||||
top_k: int = -1,
|
||||
frequency_penalty: float = 0.0,
|
||||
presence_penalty: float = 0.0,
|
||||
stop: Optional[Union[str, List[str]]] = None,
|
||||
max_new_tokens: int = 16,
|
||||
ignore_eos: bool = False,
|
||||
skip_special_tokens: bool = True,
|
||||
dtype: Optional[str] = None,
|
||||
|
||||
@@ -24,6 +24,8 @@ class ServerArgs:
|
||||
def __post_init__(self):
|
||||
if self.tokenizer_path is None:
|
||||
self.tokenizer_path = self.model_path
|
||||
if self.tp_size > 1:
|
||||
self.mem_fraction_static = 0.8
|
||||
|
||||
@staticmethod
|
||||
def add_cli_args(parser: argparse.ArgumentParser):
|
||||
|
||||
@@ -38,6 +38,26 @@ def call_generate_vllm(prompt, temperature, max_tokens, stop, url, n=1):
|
||||
return pred
|
||||
|
||||
|
||||
def call_generate_outlines(
    prompt, temperature, max_tokens, url, stop=(), regex=None, n=1
):
    """Request ``n`` completions for ``prompt`` from an HTTP generation endpoint.

    The request is a JSON POST to ``url`` carrying the prompt, the sampling
    settings and an optional ``regex`` field. The response JSON is expected to
    contain a ``"text"`` list whose entries start with the prompt; that prefix
    is removed by length before returning.

    Args:
        prompt: Prompt text sent to the server.
        temperature: Sampling temperature forwarded unchanged.
        max_tokens: Token limit forwarded unchanged.
        url: Endpoint that receives the POST request.
        stop: Stop sequence(s) forwarded as ``"stop"``. The default is an
            immutable empty tuple, which serializes to the same empty JSON
            array as the previous ``[]`` default without sharing a mutable
            object between calls.
        regex: Optional pattern forwarded as ``"regex"``.
        n: Number of completions requested.

    Returns:
        A single string when ``n == 1``, otherwise a list with one string per
        returned completion.

    Raises:
        AssertionError: If the server does not answer with HTTP 200.
    """
    data = {
        "prompt": prompt,
        "temperature": temperature,
        "max_tokens": max_tokens,
        "stop": stop,
        "regex": regex,
        "n": n,
    }
    res = requests.post(url, json=data)
    assert res.status_code == 200, f"unexpected status code {res.status_code}"

    # Every returned text begins with the prompt; drop that prefix by length.
    prefix_len = len(prompt)
    texts = res.json()["text"]
    if n == 1:
        return texts[0][prefix_len:]
    return [text[prefix_len:] for text in texts]
|
||||
|
||||
|
||||
def call_generate_srt_raw(prompt, temperature, max_tokens, stop, url):
|
||||
data = {
|
||||
"text": prompt,
|
||||
|
||||
@@ -67,7 +67,7 @@ def dump_state_text(filename, states, mode="w"):
|
||||
if isinstance(s, str):
|
||||
pass
|
||||
elif isinstance(s, ProgramState):
|
||||
s = s.text().strip()
|
||||
s = s.text()
|
||||
else:
|
||||
s = str(s)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user