Improve the control of streaming and reduce the first-token latency in streaming (#117)
@@ -21,14 +21,17 @@ class FinishReason(Enum):
 
 
 class Req:
-    def __init__(self, rid):
+    def __init__(self, rid, input_text, input_ids):
         self.rid = rid
-        self.input_text = None
-        self.input_ids = []
+        self.input_text = input_text
+        self.input_ids = input_ids
         self.output_ids = []
+
+        # For vision input
         self.pixel_values = None
         self.image_size = None
         self.image_offset = 0
+
         self.sampling_params = None
         self.return_logprob = False
         self.logprob_start_len = 0
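Note: after this hunk, a request carries its prompt text and token ids from construction instead of having them patched in afterwards. A minimal illustration of the new call signature (a sketch only; the rid and token ids are made-up placeholders, and the import path is taken from the imports shown later in this diff):

from sglang.srt.managers.router.infer_batch import Req

# Hypothetical values, for illustration only.
req = Req(0, "The capital of France is", [1, 450, 7483, 310, 3444, 338])
assert req.input_text == "The capital of France is"
assert req.output_ids == []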
@@ -46,7 +49,7 @@ class Req:
         self.logprob = None
         self.normalized_logprob = None
 
-        # for constrained decoding
+        # For constrained decoding
         self.regex_fsm = None
         self.regex_fsm_state = 0
         self.fast_forward_map = None
@@ -40,7 +40,7 @@ class RouterManager:
             for obj in out_pyobjs:
                 self.send_to_detokenizer.send_pyobj(obj)
 
-            # async sleep for recving the subsequent request, and avoiding cache miss
+            # async sleep for receiving the subsequent request and avoiding cache miss
             if len(out_pyobjs) != 0:
                 has_finished = any([obj.finished for obj in out_pyobjs])
                 if has_finished:
@@ -17,8 +17,8 @@ from sglang.srt.constrained.fsm_cache import FSMCache
 from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
 from sglang.srt.managers.io_struct import (
     BatchTokenIDOut,
-    TokenizedGenerateReqInput,
     FlushCacheReq,
+    TokenizedGenerateReqInput,
 )
 from sglang.srt.managers.router.infer_batch import Batch, ForwardMode, Req
 from sglang.srt.managers.router.model_runner import ModelRunner
@@ -194,6 +194,9 @@ class ModelRpcServer(rpyc.Service):
                     if self.running_batch.is_empty():
                         self.running_batch = None
                         break
+
+                    if self.out_pyobjs and self.running_batch.reqs[0].stream:
+                        break
             else:
                 # check the available size
                 available_size = (
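Note: this hunk is one half of the latency improvement. The inner decode loop now exits as soon as streaming output has been buffered, so it can be pushed to the detokenizer instead of waiting for the remaining decode iterations of the round. A rough sketch of the control flow (illustrative only; the loop bound and the method split are assumptions, not the actual ModelRpcServer code):

def run_decode_round(self, max_steps=10):  # max_steps is an assumed bound
    for _ in range(max_steps):
        self.forward_decode_batch(self.running_batch)

        if self.running_batch.is_empty():
            self.running_batch = None
            break

        # New in this commit: leave the round early when streaming output is
        # already queued, so it is flushed instead of sitting in a buffer.
        if self.out_pyobjs and self.running_batch.reqs[0].stream:
            break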
@@ -208,8 +211,7 @@ class ModelRpcServer(rpyc.Service):
                     )
 
                     if self.running_batch is not None and self.tp_rank == 0:
-                        if self.decode_forward_ct >= 20:
-                            self.decode_forward_ct = 0
+                        if self.decode_forward_ct % 20 == 0:
                             num_used = self.max_total_num_token - (
                                 self.token_to_kv_pool.available_size()
                                 + self.tree_cache.evictable_size()
@@ -225,11 +227,8 @@ class ModelRpcServer(rpyc.Service):
         self,
         recv_req: TokenizedGenerateReqInput,
     ):
-        req = Req(recv_req.rid)
-        req.input_text = recv_req.input_text
-        req.input_ids = recv_req.input_ids
+        req = Req(recv_req.rid, recv_req.input_text, recv_req.input_ids)
         req.pixel_values = recv_req.pixel_values
-        req.image_size = recv_req.image_size
         if req.pixel_values is not None:
             pad_value = [
                 (recv_req.image_hash) % self.model_config.vocab_size,
@@ -240,6 +239,7 @@ class ModelRpcServer(rpyc.Service):
             req.input_ids, req.image_offset = self.model_runner.model.pad_input_ids(
                 req.input_ids, pad_value, req.pixel_values.shape, req.image_size
             )
+            req.image_size = recv_req.image_size
         req.sampling_params = recv_req.sampling_params
         req.return_logprob = recv_req.return_logprob
         req.logprob_start_len = recv_req.logprob_start_len
@@ -327,9 +327,11 @@ class ModelRpcServer(rpyc.Service):
                 req.extend_input_len + req.max_new_tokens() + new_batch_total_tokens
                 < available_size
             ):
+                # Undo the insertion
                 delta = self.tree_cache.dec_ref_counter(req.last_node)
                 available_size += delta
             else:
+                # Add this request to the running batch
                 self.token_to_kv_pool.add_refs(req.prefix_indices)
                 can_run_list.append(req)
                 new_batch_total_tokens += (
@@ -421,7 +423,7 @@ class ModelRpcServer(rpyc.Service):
             return
 
         # Update batch tensors
-        self.decode_forward_ct += 1
+        self.decode_forward_ct = (self.decode_forward_ct + 1) % (1 << 30)
         batch.prepare_for_decode()
 
         # Forward
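Note: the streaming flush condition below keys off decode_forward_ct % stream_interval, so the counter can no longer be reset to zero by the periodic memory-usage check above; it now free-runs and wraps at 2**30. A small self-contained check that the wrap keeps the every-k-steps behavior for a power-of-two interval such as the new default of 8 (numbers are illustrative):

# Sketch: a counter wrapped at 2**30 still fires on every k-th step when k
# divides the wrap point (true for any power-of-two stream_interval).
WRAP = 1 << 30
stream_interval = 8

ct = 0
fired = []
for step in range(1, 41):
    ct = (ct + 1) % WRAP
    if ct % stream_interval == 0:
        fired.append(step)

assert fired == [8, 16, 24, 32, 40]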
@@ -454,7 +456,13 @@ class ModelRpcServer(rpyc.Service):
                 unfinished_indices.append(i)
 
             if req.finished or (
-                req.stream and self.decode_forward_ct % self.stream_interval == 0
+                (
+                    req.stream
+                    and (
+                        self.decode_forward_ct % self.stream_interval == 0
+                        or len(req.output_ids) == 1
+                    )
+                )
             ):
                 output_rids.append(req.rid)
                 output_tokens.append(req.output_ids)
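Note: this hunk is the other half of the first-token-latency change. A streaming request is now flushed not only every stream_interval decode steps but also as soon as its first output token exists. The predicate, restated as a standalone helper for clarity (a sketch, not the shipped code):

def should_emit(finished, stream, num_output_tokens, decode_forward_ct, stream_interval):
    # Mirrors the condition in the hunk above.
    return finished or (
        stream
        and (
            decode_forward_ct % stream_interval == 0
            or num_output_tokens == 1  # emit the very first token immediately
        )
    )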
@@ -7,7 +7,6 @@ from typing import List
 
 import numpy as np
 import torch
-import sglang
 from sglang.srt.managers.router.infer_batch import Batch, ForwardMode
 from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
 from sglang.srt.utils import is_multimodal_model
@@ -16,6 +15,8 @@ from vllm.model_executor.layers.quantization.awq import AWQConfig
 from vllm.model_executor.model_loader import _set_default_torch_dtype
 from vllm.model_executor.parallel_utils.parallel_state import initialize_model_parallel
 
+import sglang
+
 logger = logging.getLogger("model_runner")
 
 
@@ -18,9 +18,9 @@ from sglang.srt.hf_transformers_utils import (
 )
 from sglang.srt.managers.io_struct import (
     BatchStrOut,
+    FlushCacheReq,
     GenerateReqInput,
     TokenizedGenerateReqInput,
-    FlushCacheReq,
 )
 from sglang.srt.mm_utils import expand2square, process_anyres_image
 from sglang.srt.sampling_params import SamplingParams
@@ -158,7 +158,7 @@ class LlavaLlamaForCausalLM(nn.Module):
                             num_patch_height, num_patch_width, height, width, -1
                         )
                     else:
-                        raise NotImplementedError
+                        raise NotImplementedError()
                     if "unpad" in self.mm_patch_merge_type:
                         image_feature = image_feature.permute(
                             4, 0, 2, 1, 3
@@ -19,7 +19,7 @@ class ServerArgs:
     schedule_heuristic: str = "lpm"
     schedule_conservativeness: float = 1.0
     random_seed: int = 42
-    stream_interval: int = 2
+    stream_interval: int = 8
     disable_log_stats: bool = False
     log_stats_interval: int = 10
     log_level: str = "info"
@@ -132,7 +132,7 @@ class ServerArgs:
             "--stream-interval",
             type=int,
             default=ServerArgs.stream_interval,
-            help="The interval in terms of token length for streaming",
+            help="The interval (or buffer size) for streaming in terms of the token length. A smaller value makes streaming smoother, while a larger value makes the throughput higher",
         )
         parser.add_argument(
             "--log-level",
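Note: the new help text spells out the trade-off: a smaller interval flushes partial output more often (smoother streaming, lower throughput), while a larger interval batches more tokens per flush. To keep the previous, smoother behavior after the default moves from 2 to 8, the old value can be passed back explicitly; for example, reusing the launch command from the test docstrings below (illustrative):

python3 -m sglang.launch_server --model-path TinyLlama/TinyLlama-1.1B-Chat-v0.4 --port 30000 --stream-interval 2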
@@ -28,7 +28,7 @@ def test_generate_worker(model_path, tp_rank, tp_size):
 
     reqs = []
     for i in range(len(prompts)):
-        req = Req(i)
+        req = Req(i, None, None)
         req.input_ids = tokenizer.encode(prompts[i])[:cut_num]
         req.sampling_params = sampling_params
         reqs.append(req)
@@ -112,6 +112,7 @@ def test_generate_worker(
     prefill_params = (
         torch.tensor(np.array(input_ids)).cuda(),
         np.array(pixel_values),
+        [None],
         [offset],
         *params,
     )
@@ -1,5 +1,8 @@
 """
+Usage:
 python3 -m sglang.launch_server --model-path TinyLlama/TinyLlama-1.1B-Chat-v0.4 --port 30000
+python3 test_httpserver_decode.py
+
 
 Output:
 The capital of France is Paris.\nThe capital of the United States is Washington, D.C.\nThe capital of Canada is Ottawa.\nThe capital of Japan is Tokyo
@@ -1,5 +1,7 @@
 """
+Usage:
 python3 -m sglang.launch_server --model-path TinyLlama/TinyLlama-1.1B-Chat-v0.4 --port 30000
+python3 test_httpserver_decode_stream.py
 
 Output:
 The capital of France is Paris.\nThe capital of the United States is Washington, D.C.\nThe capital of Canada is Ottawa.\nThe capital of Japan is Tokyo
@@ -1,5 +1,7 @@
 """
+Usage:
 python3 -m sglang.launch_server --model-path liuhaotian/llava-v1.5-7b --tokenizer-path llava-hf/llava-1.5-7b-hf --port 30000
+python3 test_httpserver_llava.py
 
 Output:
 The image features a man standing on the back of a yellow taxi cab, holding
@@ -64,9 +66,12 @@ def test_streaming(args):
     )
 
     prev = 0
-    for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
-        if chunk:
-            data = json.loads(chunk.decode())
+    for chunk in response.iter_lines(decode_unicode=False):
+        chunk = chunk.decode("utf-8")
+        if chunk and chunk.startswith("data:"):
+            if chunk == "data: [DONE]":
+                break
+            data = json.loads(chunk[5:].strip("\n"))
             output = data["text"].strip()
             print(output[prev:], end="", flush=True)
             prev = len(output)
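Note: the updated test reads the stream as data: ... lines terminated by a data: [DONE] sentinel instead of NUL-delimited JSON chunks. A self-contained client using the same framing (a sketch: the endpoint, port, and payload fields simply mirror this test and the docstring above; they are not documented API guarantees):

import json

import requests

response = requests.post(
    "http://127.0.0.1:30000/generate",
    json={
        "text": "The capital of France is",
        "sampling_params": {"temperature": 0, "max_new_tokens": 64},
        "stream": True,
    },
    stream=True,
)

prev = 0
for chunk in response.iter_lines(decode_unicode=False):
    chunk = chunk.decode("utf-8")
    if chunk and chunk.startswith("data:"):
        if chunk == "data: [DONE]":
            break
        data = json.loads(chunk[5:].strip("\n"))
        output = data["text"].strip()
        print(output[prev:], end="", flush=True)
        prev = len(output)
print()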