Improve the control of streaming and improve the first token latency in streaming (#117)
This commit is contained in:
@@ -21,14 +21,17 @@ class FinishReason(Enum):
|
||||
|
||||
|
||||
class Req:
|
||||
def __init__(self, rid):
|
||||
def __init__(self, rid, input_text, input_ids):
|
||||
self.rid = rid
|
||||
self.input_text = None
|
||||
self.input_ids = []
|
||||
self.input_text = input_text
|
||||
self.input_ids = input_ids
|
||||
self.output_ids = []
|
||||
|
||||
# For vision input
|
||||
self.pixel_values = None
|
||||
self.image_size = None
|
||||
self.image_offset = 0
|
||||
|
||||
self.sampling_params = None
|
||||
self.return_logprob = False
|
||||
self.logprob_start_len = 0
|
||||
@@ -46,7 +49,7 @@ class Req:
|
||||
self.logprob = None
|
||||
self.normalized_logprob = None
|
||||
|
||||
# for constrained decoding
|
||||
# For constrained decoding
|
||||
self.regex_fsm = None
|
||||
self.regex_fsm_state = 0
|
||||
self.fast_forward_map = None
|
||||
|
||||
@@ -40,7 +40,7 @@ class RouterManager:
|
||||
for obj in out_pyobjs:
|
||||
self.send_to_detokenizer.send_pyobj(obj)
|
||||
|
||||
# async sleep for recving the subsequent request, and avoiding cache miss
|
||||
# async sleep for receiving the subsequent request and avoiding cache miss
|
||||
if len(out_pyobjs) != 0:
|
||||
has_finished = any([obj.finished for obj in out_pyobjs])
|
||||
if has_finished:
|
||||
|
||||
@@ -17,8 +17,8 @@ from sglang.srt.constrained.fsm_cache import FSMCache
|
||||
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
|
||||
from sglang.srt.managers.io_struct import (
|
||||
BatchTokenIDOut,
|
||||
TokenizedGenerateReqInput,
|
||||
FlushCacheReq,
|
||||
TokenizedGenerateReqInput,
|
||||
)
|
||||
from sglang.srt.managers.router.infer_batch import Batch, ForwardMode, Req
|
||||
from sglang.srt.managers.router.model_runner import ModelRunner
|
||||
@@ -194,6 +194,9 @@ class ModelRpcServer(rpyc.Service):
|
||||
if self.running_batch.is_empty():
|
||||
self.running_batch = None
|
||||
break
|
||||
|
||||
if self.out_pyobjs and self.running_batch.reqs[0].stream:
|
||||
break
|
||||
else:
|
||||
# check the available size
|
||||
available_size = (
|
||||
@@ -208,8 +211,7 @@ class ModelRpcServer(rpyc.Service):
|
||||
)
|
||||
|
||||
if self.running_batch is not None and self.tp_rank == 0:
|
||||
if self.decode_forward_ct >= 20:
|
||||
self.decode_forward_ct = 0
|
||||
if self.decode_forward_ct % 20 == 0:
|
||||
num_used = self.max_total_num_token - (
|
||||
self.token_to_kv_pool.available_size()
|
||||
+ self.tree_cache.evictable_size()
|
||||
@@ -225,11 +227,8 @@ class ModelRpcServer(rpyc.Service):
|
||||
self,
|
||||
recv_req: TokenizedGenerateReqInput,
|
||||
):
|
||||
req = Req(recv_req.rid)
|
||||
req.input_text = recv_req.input_text
|
||||
req.input_ids = recv_req.input_ids
|
||||
req = Req(recv_req.rid, recv_req.input_text, recv_req.input_ids)
|
||||
req.pixel_values = recv_req.pixel_values
|
||||
req.image_size = recv_req.image_size
|
||||
if req.pixel_values is not None:
|
||||
pad_value = [
|
||||
(recv_req.image_hash) % self.model_config.vocab_size,
|
||||
@@ -240,6 +239,7 @@ class ModelRpcServer(rpyc.Service):
|
||||
req.input_ids, req.image_offset = self.model_runner.model.pad_input_ids(
|
||||
req.input_ids, pad_value, req.pixel_values.shape, req.image_size
|
||||
)
|
||||
req.image_size = recv_req.image_size
|
||||
req.sampling_params = recv_req.sampling_params
|
||||
req.return_logprob = recv_req.return_logprob
|
||||
req.logprob_start_len = recv_req.logprob_start_len
|
||||
@@ -327,9 +327,11 @@ class ModelRpcServer(rpyc.Service):
|
||||
req.extend_input_len + req.max_new_tokens() + new_batch_total_tokens
|
||||
< available_size
|
||||
):
|
||||
# Undo the insertion
|
||||
delta = self.tree_cache.dec_ref_counter(req.last_node)
|
||||
available_size += delta
|
||||
else:
|
||||
# Add this request to the running batch
|
||||
self.token_to_kv_pool.add_refs(req.prefix_indices)
|
||||
can_run_list.append(req)
|
||||
new_batch_total_tokens += (
|
||||
@@ -421,7 +423,7 @@ class ModelRpcServer(rpyc.Service):
|
||||
return
|
||||
|
||||
# Update batch tensors
|
||||
self.decode_forward_ct += 1
|
||||
self.decode_forward_ct = (self.decode_forward_ct + 1) % (1 << 30)
|
||||
batch.prepare_for_decode()
|
||||
|
||||
# Forward
|
||||
@@ -454,7 +456,13 @@ class ModelRpcServer(rpyc.Service):
|
||||
unfinished_indices.append(i)
|
||||
|
||||
if req.finished or (
|
||||
req.stream and self.decode_forward_ct % self.stream_interval == 0
|
||||
(
|
||||
req.stream
|
||||
and (
|
||||
self.decode_forward_ct % self.stream_interval == 0
|
||||
or len(req.output_ids) == 1
|
||||
)
|
||||
)
|
||||
):
|
||||
output_rids.append(req.rid)
|
||||
output_tokens.append(req.output_ids)
|
||||
|
||||
@@ -7,7 +7,6 @@ from typing import List
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import sglang
|
||||
from sglang.srt.managers.router.infer_batch import Batch, ForwardMode
|
||||
from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
|
||||
from sglang.srt.utils import is_multimodal_model
|
||||
@@ -16,6 +15,8 @@ from vllm.model_executor.layers.quantization.awq import AWQConfig
|
||||
from vllm.model_executor.model_loader import _set_default_torch_dtype
|
||||
from vllm.model_executor.parallel_utils.parallel_state import initialize_model_parallel
|
||||
|
||||
import sglang
|
||||
|
||||
logger = logging.getLogger("model_runner")
|
||||
|
||||
|
||||
|
||||
@@ -18,9 +18,9 @@ from sglang.srt.hf_transformers_utils import (
|
||||
)
|
||||
from sglang.srt.managers.io_struct import (
|
||||
BatchStrOut,
|
||||
FlushCacheReq,
|
||||
GenerateReqInput,
|
||||
TokenizedGenerateReqInput,
|
||||
FlushCacheReq,
|
||||
)
|
||||
from sglang.srt.mm_utils import expand2square, process_anyres_image
|
||||
from sglang.srt.sampling_params import SamplingParams
|
||||
|
||||
@@ -158,7 +158,7 @@ class LlavaLlamaForCausalLM(nn.Module):
|
||||
num_patch_height, num_patch_width, height, width, -1
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
raise NotImplementedError()
|
||||
if "unpad" in self.mm_patch_merge_type:
|
||||
image_feature = image_feature.permute(
|
||||
4, 0, 2, 1, 3
|
||||
|
||||
@@ -19,7 +19,7 @@ class ServerArgs:
|
||||
schedule_heuristic: str = "lpm"
|
||||
schedule_conservativeness: float = 1.0
|
||||
random_seed: int = 42
|
||||
stream_interval: int = 2
|
||||
stream_interval: int = 8
|
||||
disable_log_stats: bool = False
|
||||
log_stats_interval: int = 10
|
||||
log_level: str = "info"
|
||||
@@ -132,7 +132,7 @@ class ServerArgs:
|
||||
"--stream-interval",
|
||||
type=int,
|
||||
default=ServerArgs.stream_interval,
|
||||
help="The interval in terms of token length for streaming",
|
||||
help="The interval (or buffer size) for streaming in terms of the token length. A smaller value makes streaming smoother, while a larger value makes the throughput higher",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--log-level",
|
||||
|
||||
@@ -28,7 +28,7 @@ def test_generate_worker(model_path, tp_rank, tp_size):
|
||||
|
||||
reqs = []
|
||||
for i in range(len(prompts)):
|
||||
req = Req(i)
|
||||
req = Req(i, None, None)
|
||||
req.input_ids = tokenizer.encode(prompts[i])[:cut_num]
|
||||
req.sampling_params = sampling_params
|
||||
reqs.append(req)
|
||||
|
||||
@@ -112,6 +112,7 @@ def test_generate_worker(
|
||||
prefill_params = (
|
||||
torch.tensor(np.array(input_ids)).cuda(),
|
||||
np.array(pixel_values),
|
||||
[None],
|
||||
[offset],
|
||||
*params,
|
||||
)
|
||||
|
||||
@@ -1,5 +1,8 @@
|
||||
"""
|
||||
Usage:
|
||||
python3 -m sglang.launch_server --model-path TinyLlama/TinyLlama-1.1B-Chat-v0.4 --port 30000
|
||||
python3 test_httpserver_decode.py
|
||||
|
||||
|
||||
Output:
|
||||
The capital of France is Paris.\nThe capital of the United States is Washington, D.C.\nThe capital of Canada is Ottawa.\nThe capital of Japan is Tokyo
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
"""
|
||||
Usage:
|
||||
python3 -m sglang.launch_server --model-path TinyLlama/TinyLlama-1.1B-Chat-v0.4 --port 30000
|
||||
python3 test_httpserver_decode_stream.py
|
||||
|
||||
Output:
|
||||
The capital of France is Paris.\nThe capital of the United States is Washington, D.C.\nThe capital of Canada is Ottawa.\nThe capital of Japan is Tokyo
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
"""
|
||||
Usage:
|
||||
python3 -m sglang.launch_server --model-path liuhaotian/llava-v1.5-7b --tokenizer-path llava-hf/llava-1.5-7b-hf --port 30000
|
||||
python3 test_httpserver_llava.py
|
||||
|
||||
Output:
|
||||
The image features a man standing on the back of a yellow taxi cab, holding
|
||||
@@ -64,9 +66,12 @@ def test_streaming(args):
|
||||
)
|
||||
|
||||
prev = 0
|
||||
for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
|
||||
if chunk:
|
||||
data = json.loads(chunk.decode())
|
||||
for chunk in response.iter_lines(decode_unicode=False):
|
||||
chunk = chunk.decode("utf-8")
|
||||
if chunk and chunk.startswith("data:"):
|
||||
if chunk == "data: [DONE]":
|
||||
break
|
||||
data = json.loads(chunk[5:].strip("\n"))
|
||||
output = data["text"].strip()
|
||||
print(output[prev:], end="", flush=True)
|
||||
prev = len(output)
|
||||
|
||||
Reference in New Issue
Block a user