From 6f560c761b2fc2f577682d0cfda62630f37a3bb0 Mon Sep 17 00:00:00 2001
From: Lianmin Zheng
Date: Mon, 29 Jan 2024 17:05:42 -0800
Subject: [PATCH] Improve the control of streaming and improve the first token latency in streaming (#117)

---
 .../sglang/srt/managers/router/infer_batch.py | 11 +++++---
 python/sglang/srt/managers/router/manager.py  |  2 +-
 .../sglang/srt/managers/router/model_rpc.py   | 26 ++++++++++++-------
 .../srt/managers/router/model_runner.py       |  3 ++-
 .../sglang/srt/managers/tokenizer_manager.py  |  2 +-
 python/sglang/srt/models/llava.py             |  2 +-
 python/sglang/srt/server_args.py              |  4 +--
 test/srt/model/test_llama_extend.py           |  2 +-
 test/srt/model/test_llava_low_api.py          |  1 +
 test/srt/test_httpserver_decode.py            |  3 +++
 test/srt/test_httpserver_decode_stream.py     |  2 ++
 test/srt/test_httpserver_llava.py             | 11 +++++---
 12 files changed, 46 insertions(+), 23 deletions(-)

diff --git a/python/sglang/srt/managers/router/infer_batch.py b/python/sglang/srt/managers/router/infer_batch.py
index 00ada2955..c5aa88615 100644
--- a/python/sglang/srt/managers/router/infer_batch.py
+++ b/python/sglang/srt/managers/router/infer_batch.py
@@ -21,14 +21,17 @@ class FinishReason(Enum):
 
 
 class Req:
-    def __init__(self, rid):
+    def __init__(self, rid, input_text, input_ids):
         self.rid = rid
-        self.input_text = None
-        self.input_ids = []
+        self.input_text = input_text
+        self.input_ids = input_ids
         self.output_ids = []
+
+        # For vision input
         self.pixel_values = None
         self.image_size = None
         self.image_offset = 0
+
         self.sampling_params = None
         self.return_logprob = False
         self.logprob_start_len = 0
@@ -46,7 +49,7 @@ class Req:
         self.logprob = None
         self.normalized_logprob = None
 
-        # for constrained decoding
+        # For constrained decoding
         self.regex_fsm = None
         self.regex_fsm_state = 0
         self.fast_forward_map = None
diff --git a/python/sglang/srt/managers/router/manager.py b/python/sglang/srt/managers/router/manager.py
index 0732d0fa8..4dc7d1f1c 100644
--- a/python/sglang/srt/managers/router/manager.py
+++ b/python/sglang/srt/managers/router/manager.py
@@ -40,7 +40,7 @@ class RouterManager:
             for obj in out_pyobjs:
                 self.send_to_detokenizer.send_pyobj(obj)
 
-            # async sleep for recving the subsequent request, and avoiding cache miss
+            # async sleep for receiving the subsequent request and avoiding cache miss
             if len(out_pyobjs) != 0:
                 has_finished = any([obj.finished for obj in out_pyobjs])
                 if has_finished:
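
Aside: the infer_batch.py change above means callers now pass input_text and input_ids at Req construction time instead of assigning them afterwards. A minimal sketch of the new call pattern (the token ids and the SamplingParams arguments below are made-up illustrative values, not part of this patch):

    from sglang.srt.managers.router.infer_batch import Req
    from sglang.srt.sampling_params import SamplingParams

    # New three-argument signature: rid, input_text, input_ids.
    input_ids = [1, 450, 7483, 310, 3444, 338]  # made-up token ids for illustration
    req = Req(0, "The capital of France is", input_ids)
    req.sampling_params = SamplingParams(max_new_tokens=32)  # still assigned afterwards
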
diff --git a/python/sglang/srt/managers/router/model_rpc.py b/python/sglang/srt/managers/router/model_rpc.py
index 199a8974b..eb5fc2f43 100644
--- a/python/sglang/srt/managers/router/model_rpc.py
+++ b/python/sglang/srt/managers/router/model_rpc.py
@@ -17,8 +17,8 @@ from sglang.srt.constrained.fsm_cache import FSMCache
 from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
 from sglang.srt.managers.io_struct import (
     BatchTokenIDOut,
-    TokenizedGenerateReqInput,
     FlushCacheReq,
+    TokenizedGenerateReqInput,
 )
 from sglang.srt.managers.router.infer_batch import Batch, ForwardMode, Req
 from sglang.srt.managers.router.model_runner import ModelRunner
@@ -194,6 +194,9 @@ class ModelRpcServer(rpyc.Service):
                 if self.running_batch.is_empty():
                     self.running_batch = None
                     break
+
+                if self.out_pyobjs and self.running_batch.reqs[0].stream:
+                    break
             else:
                 # check the available size
                 available_size = (
@@ -208,8 +211,7 @@ class ModelRpcServer(rpyc.Service):
                 )
 
         if self.running_batch is not None and self.tp_rank == 0:
-            if self.decode_forward_ct >= 20:
-                self.decode_forward_ct = 0
+            if self.decode_forward_ct % 20 == 0:
                 num_used = self.max_total_num_token - (
                     self.token_to_kv_pool.available_size()
                     + self.tree_cache.evictable_size()
@@ -225,11 +227,8 @@ class ModelRpcServer(rpyc.Service):
         self,
         recv_req: TokenizedGenerateReqInput,
     ):
-        req = Req(recv_req.rid)
-        req.input_text = recv_req.input_text
-        req.input_ids = recv_req.input_ids
+        req = Req(recv_req.rid, recv_req.input_text, recv_req.input_ids)
         req.pixel_values = recv_req.pixel_values
-        req.image_size = recv_req.image_size
         if req.pixel_values is not None:
             pad_value = [
                 (recv_req.image_hash) % self.model_config.vocab_size,
@@ -240,6 +239,7 @@ class ModelRpcServer(rpyc.Service):
             req.input_ids, req.image_offset = self.model_runner.model.pad_input_ids(
                 req.input_ids, pad_value, req.pixel_values.shape, req.image_size
             )
+            req.image_size = recv_req.image_size
         req.sampling_params = recv_req.sampling_params
         req.return_logprob = recv_req.return_logprob
         req.logprob_start_len = recv_req.logprob_start_len
@@ -327,9 +327,11 @@ class ModelRpcServer(rpyc.Service):
                 req.extend_input_len + req.max_new_tokens() + new_batch_total_tokens
                 < available_size
             ):
+                # Undo the insertion
                 delta = self.tree_cache.dec_ref_counter(req.last_node)
                 available_size += delta
             else:
+                # Add this request to the running batch
                 self.token_to_kv_pool.add_refs(req.prefix_indices)
                 can_run_list.append(req)
                 new_batch_total_tokens += (
@@ -421,7 +423,7 @@ class ModelRpcServer(rpyc.Service):
             return
 
         # Update batch tensors
-        self.decode_forward_ct += 1
+        self.decode_forward_ct = (self.decode_forward_ct + 1) % (1 << 30)
         batch.prepare_for_decode()
 
         # Forward
@@ -454,7 +456,13 @@ class ModelRpcServer(rpyc.Service):
                 unfinished_indices.append(i)
 
             if req.finished or (
-                req.stream and self.decode_forward_ct % self.stream_interval == 0
+                (
+                    req.stream
+                    and (
+                        self.decode_forward_ct % self.stream_interval == 0
+                        or len(req.output_ids) == 1
+                    )
+                )
             ):
                 output_rids.append(req.rid)
                 output_tokens.append(req.output_ids)
diff --git a/python/sglang/srt/managers/router/model_runner.py b/python/sglang/srt/managers/router/model_runner.py
index c85ec534d..7d72c6c70 100644
--- a/python/sglang/srt/managers/router/model_runner.py
+++ b/python/sglang/srt/managers/router/model_runner.py
@@ -7,7 +7,6 @@ from typing import List
 
 import numpy as np
 import torch
-import sglang
 from sglang.srt.managers.router.infer_batch import Batch, ForwardMode
 from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
 from sglang.srt.utils import is_multimodal_model
@@ -16,6 +15,8 @@ from vllm.model_executor.layers.quantization.awq import AWQConfig
 from vllm.model_executor.model_loader import _set_default_torch_dtype
 from vllm.model_executor.parallel_utils.parallel_state import initialize_model_parallel
 
+import sglang
+
 logger = logging.getLogger("model_runner")
diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py
index d08b33634..2213858bf 100644
--- a/python/sglang/srt/managers/tokenizer_manager.py
+++ b/python/sglang/srt/managers/tokenizer_manager.py
@@ -18,9 +18,9 @@ from sglang.srt.hf_transformers_utils import (
 )
 from sglang.srt.managers.io_struct import (
     BatchStrOut,
+    FlushCacheReq,
     GenerateReqInput,
     TokenizedGenerateReqInput,
-    FlushCacheReq,
 )
 from sglang.srt.mm_utils import expand2square, process_anyres_image
 from sglang.srt.sampling_params import SamplingParams
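
Aside: the net effect of the model_rpc.py changes above is that a streaming request flushes its very first decoded token immediately (len(req.output_ids) == 1) and then flushes every stream_interval decode steps, while the prefill loop also breaks out early once streamed output is pending. A self-contained sketch of just the flush predicate (the helper itself is illustrative, not part of the patch; the names mirror the fields above):

    def should_flush(
        finished: bool,
        stream: bool,
        decode_forward_ct: int,
        num_output_ids: int,
        stream_interval: int,
    ) -> bool:
        # Flush when the request is finished, on the very first decoded token,
        # or on every `stream_interval`-th decode step.
        return finished or (
            stream
            and (decode_forward_ct % stream_interval == 0 or num_output_ids == 1)
        )

    assert should_flush(False, True, 7, 1, 8)       # first token: flushed immediately
    assert not should_flush(False, True, 9, 3, 8)   # buffered between intervals
    assert should_flush(False, True, 16, 10, 8)     # periodic flush every 8 steps
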
diff --git a/python/sglang/srt/models/llava.py b/python/sglang/srt/models/llava.py
index cd3e93cbd..efc362f59 100644
--- a/python/sglang/srt/models/llava.py
+++ b/python/sglang/srt/models/llava.py
@@ -158,7 +158,7 @@ class LlavaLlamaForCausalLM(nn.Module):
                         num_patch_height, num_patch_width, height, width, -1
                     )
                 else:
-                    raise NotImplementedError
+                    raise NotImplementedError()
                 if "unpad" in self.mm_patch_merge_type:
                     image_feature = image_feature.permute(
                         4, 0, 2, 1, 3
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index 5fcb6f5c2..17e436d8d 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -19,7 +19,7 @@ class ServerArgs:
     schedule_heuristic: str = "lpm"
     schedule_conservativeness: float = 1.0
     random_seed: int = 42
-    stream_interval: int = 2
+    stream_interval: int = 8
     disable_log_stats: bool = False
     log_stats_interval: int = 10
     log_level: str = "info"
@@ -132,7 +132,7 @@ class ServerArgs:
             "--stream-interval",
             type=int,
             default=ServerArgs.stream_interval,
-            help="The interval in terms of token length for streaming",
+            help="The interval (or buffer size) for streaming in terms of the token length. A smaller value makes streaming smoother, while a larger value makes the throughput higher",
         )
         parser.add_argument(
             "--log-level",
diff --git a/test/srt/model/test_llama_extend.py b/test/srt/model/test_llama_extend.py
index b01549878..ae8df9d05 100644
--- a/test/srt/model/test_llama_extend.py
+++ b/test/srt/model/test_llama_extend.py
@@ -28,7 +28,7 @@ def test_generate_worker(model_path, tp_rank, tp_size):
 
     reqs = []
     for i in range(len(prompts)):
-        req = Req(i)
+        req = Req(i, None, None)
         req.input_ids = tokenizer.encode(prompts[i])[:cut_num]
         req.sampling_params = sampling_params
         reqs.append(req)
diff --git a/test/srt/model/test_llava_low_api.py b/test/srt/model/test_llava_low_api.py
index 00cdd622f..f6a77a74d 100644
--- a/test/srt/model/test_llava_low_api.py
+++ b/test/srt/model/test_llava_low_api.py
@@ -112,6 +112,7 @@ def test_generate_worker(
     prefill_params = (
         torch.tensor(np.array(input_ids)).cuda(),
         np.array(pixel_values),
+        [None],
         [offset],
         *params,
     )
diff --git a/test/srt/test_httpserver_decode.py b/test/srt/test_httpserver_decode.py
index 21ec0be6a..b26eb030d 100644
--- a/test/srt/test_httpserver_decode.py
+++ b/test/srt/test_httpserver_decode.py
@@ -1,5 +1,8 @@
 """
+Usage:
+python3 -m sglang.launch_server --model-path TinyLlama/TinyLlama-1.1B-Chat-v0.4 --port 30000
+python3 test_httpserver_decode.py
 
 Output:
 The capital of France is Paris.\nThe capital of the United States is Washington, D.C.\nThe capital of Canada is Ottawa.\nThe capital of Japan is Tokyo
diff --git a/test/srt/test_httpserver_decode_stream.py b/test/srt/test_httpserver_decode_stream.py
index e397f137d..3d63e66cb 100644
--- a/test/srt/test_httpserver_decode_stream.py
+++ b/test/srt/test_httpserver_decode_stream.py
@@ -1,5 +1,7 @@
 """
+Usage: python3 -m sglang.launch_server --model-path TinyLlama/TinyLlama-1.1B-Chat-v0.4 --port 30000
+python3 test_httpserver_decode_stream.py
 
 Output:
 The capital of France is Paris.\nThe capital of the United States is Washington, D.C.\nThe capital of Canada is Ottawa.\nThe capital of Japan is Tokyo
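
Aside: with the server_args.py change above, the default --stream-interval rises from 2 to 8, trading streaming smoothness for throughput. Anyone who prefers the old, smoother cadence can restore it at launch time, e.g. (model path borrowed from the test docstrings above):

    python3 -m sglang.launch_server --model-path TinyLlama/TinyLlama-1.1B-Chat-v0.4 --port 30000 --stream-interval 2
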
diff --git a/test/srt/test_httpserver_llava.py b/test/srt/test_httpserver_llava.py
index 042f4229d..25bb79c81 100644
--- a/test/srt/test_httpserver_llava.py
+++ b/test/srt/test_httpserver_llava.py
@@ -1,5 +1,7 @@
 """
+Usage: python3 -m sglang.launch_server --model-path liuhaotian/llava-v1.5-7b --tokenizer-path llava-hf/llava-1.5-7b-hf --port 30000
+python3 test_httpserver_llava.py
 
 Output:
 The image features a man standing on the back of a yellow taxi cab, holding
@@ -64,9 +66,12 @@ def test_streaming(args):
     )
 
     prev = 0
-    for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
-        if chunk:
-            data = json.loads(chunk.decode())
+    for chunk in response.iter_lines(decode_unicode=False):
+        chunk = chunk.decode("utf-8")
+        if chunk and chunk.startswith("data:"):
+            if chunk == "data: [DONE]":
+                break
+            data = json.loads(chunk[5:].strip("\n"))
             output = data["text"].strip()
             print(output[prev:], end="", flush=True)
             prev = len(output)
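
Aside: the test_streaming change above reflects a switch from NUL-delimited chunks to SSE-style "data: ..." lines terminated by "data: [DONE]". A minimal standalone client in the same style as the updated test (assumes a server on localhost:30000 exposing the /generate endpoint the tests use; prompt and sampling parameters are illustrative):

    import json

    import requests

    response = requests.post(
        "http://localhost:30000/generate",
        json={
            "text": "The capital of France is",
            "sampling_params": {"temperature": 0, "max_new_tokens": 32},
            "stream": True,
        },
        stream=True,
    )

    # Print only the incremental suffix of the accumulated text, as the test does.
    prev = 0
    for chunk in response.iter_lines(decode_unicode=False):
        chunk = chunk.decode("utf-8")
        if chunk and chunk.startswith("data:"):
            if chunk == "data: [DONE]":
                break
            data = json.loads(chunk[5:].strip("\n"))
            output = data["text"].strip()
            print(output[prev:], end="", flush=True)
            prev = len(output)
    print()
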