Improve the control of streaming and improve the first token latency in streaming (#117)

Author: Lianmin Zheng (committed by GitHub)
Date:   2024-01-29 17:05:42 -08:00
Commit: 6f560c761b (parent cd6872334e)

12 changed files with 46 additions and 23 deletions

View File

@@ -21,14 +21,17 @@ class FinishReason(Enum):
 class Req:
-    def __init__(self, rid):
+    def __init__(self, rid, input_text, input_ids):
         self.rid = rid
-        self.input_text = None
-        self.input_ids = []
+        self.input_text = input_text
+        self.input_ids = input_ids
         self.output_ids = []

+        # For vision input
         self.pixel_values = None
         self.image_size = None
         self.image_offset = 0

         self.sampling_params = None
         self.return_logprob = False
         self.logprob_start_len = 0
@@ -46,7 +49,7 @@ class Req:
         self.logprob = None
         self.normalized_logprob = None

-        # for constrained decoding
+        # For constrained decoding
         self.regex_fsm = None
         self.regex_fsm_state = 0
         self.fast_forward_map = None

View File

@@ -40,7 +40,7 @@ class RouterManager:
             for obj in out_pyobjs:
                 self.send_to_detokenizer.send_pyobj(obj)

-            # async sleep for recving the subsequent request, and avoiding cache miss
+            # async sleep for receiving the subsequent request and avoiding cache miss
             if len(out_pyobjs) != 0:
                 has_finished = any([obj.finished for obj in out_pyobjs])
                 if has_finished:

View File

@@ -17,8 +17,8 @@ from sglang.srt.constrained.fsm_cache import FSMCache
 from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
 from sglang.srt.managers.io_struct import (
     BatchTokenIDOut,
-    TokenizedGenerateReqInput,
     FlushCacheReq,
+    TokenizedGenerateReqInput,
 )
 from sglang.srt.managers.router.infer_batch import Batch, ForwardMode, Req
 from sglang.srt.managers.router.model_runner import ModelRunner
@@ -194,6 +194,9 @@ class ModelRpcServer(rpyc.Service):
             if self.running_batch.is_empty():
                 self.running_batch = None
                 break
+
+            if self.out_pyobjs and self.running_batch.reqs[0].stream:
+                break
         else:
             # check the available size
             available_size = (
@@ -208,8 +211,7 @@ class ModelRpcServer(rpyc.Service):
                 )

                 if self.running_batch is not None and self.tp_rank == 0:
-                    if self.decode_forward_ct >= 20:
-                        self.decode_forward_ct = 0
+                    if self.decode_forward_ct % 20 == 0:
                         num_used = self.max_total_num_token - (
                             self.token_to_kv_pool.available_size()
                             + self.tree_cache.evictable_size()
@@ -225,11 +227,8 @@ class ModelRpcServer(rpyc.Service):
         self,
         recv_req: TokenizedGenerateReqInput,
     ):
-        req = Req(recv_req.rid)
-        req.input_text = recv_req.input_text
-        req.input_ids = recv_req.input_ids
+        req = Req(recv_req.rid, recv_req.input_text, recv_req.input_ids)
         req.pixel_values = recv_req.pixel_values
-        req.image_size = recv_req.image_size
         if req.pixel_values is not None:
             pad_value = [
                 (recv_req.image_hash) % self.model_config.vocab_size,
@@ -240,6 +239,7 @@ class ModelRpcServer(rpyc.Service):
             req.input_ids, req.image_offset = self.model_runner.model.pad_input_ids(
                 req.input_ids, pad_value, req.pixel_values.shape, req.image_size
             )
+            req.image_size = recv_req.image_size
         req.sampling_params = recv_req.sampling_params
         req.return_logprob = recv_req.return_logprob
         req.logprob_start_len = recv_req.logprob_start_len
@@ -327,9 +327,11 @@ class ModelRpcServer(rpyc.Service):
                 req.extend_input_len + req.max_new_tokens() + new_batch_total_tokens
                 < available_size
             ):
+                # Undo the insertion
                 delta = self.tree_cache.dec_ref_counter(req.last_node)
                 available_size += delta
             else:
+                # Add this request to the running batch
                 self.token_to_kv_pool.add_refs(req.prefix_indices)
                 can_run_list.append(req)
                 new_batch_total_tokens += (
@@ -421,7 +423,7 @@ class ModelRpcServer(rpyc.Service):
             return

         # Update batch tensors
-        self.decode_forward_ct += 1
+        self.decode_forward_ct = (self.decode_forward_ct + 1) % (1 << 30)
         batch.prepare_for_decode()

         # Forward
@@ -454,7 +456,13 @@ class ModelRpcServer(rpyc.Service):
                 unfinished_indices.append(i)

             if req.finished or (
-                req.stream and self.decode_forward_ct % self.stream_interval == 0
+                (
+                    req.stream
+                    and (
+                        self.decode_forward_ct % self.stream_interval == 0
+                        or len(req.output_ids) == 1
+                    )
+                )
             ):
                 output_rids.append(req.rid)
                 output_tokens.append(req.output_ids)
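
Note on the streaming change above: combined with the larger default stream_interval (see the ServerArgs change below), the effect is that a streaming request is flushed to the detokenizer as soon as its first token exists, and afterwards only once every stream_interval decode steps. A minimal, self-contained sketch of that flush rule (the helper name and signature are hypothetical; the real check lives inline in ModelRpcServer and uses self.decode_forward_ct, self.stream_interval, and req.output_ids):

def should_flush(finished: bool, stream: bool, decode_forward_ct: int,
                 num_output_tokens: int, stream_interval: int = 8) -> bool:
    # Finished requests are always returned.
    if finished:
        return True
    # Streaming requests: send the very first token immediately (low
    # first-token latency), then batch subsequent tokens and flush only
    # every `stream_interval` decode steps (higher throughput).
    return stream and (
        decode_forward_ct % stream_interval == 0 or num_output_tokens == 1
    )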

View File

@@ -7,7 +7,6 @@ from typing import List
 import numpy as np
 import torch

-import sglang
 from sglang.srt.managers.router.infer_batch import Batch, ForwardMode
 from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
 from sglang.srt.utils import is_multimodal_model
@@ -16,6 +15,8 @@ from vllm.model_executor.layers.quantization.awq import AWQConfig
 from vllm.model_executor.model_loader import _set_default_torch_dtype
 from vllm.model_executor.parallel_utils.parallel_state import initialize_model_parallel

+import sglang
+
 logger = logging.getLogger("model_runner")

View File

@@ -18,9 +18,9 @@ from sglang.srt.hf_transformers_utils import (
 )
 from sglang.srt.managers.io_struct import (
     BatchStrOut,
+    FlushCacheReq,
     GenerateReqInput,
     TokenizedGenerateReqInput,
-    FlushCacheReq,
 )
 from sglang.srt.mm_utils import expand2square, process_anyres_image
 from sglang.srt.sampling_params import SamplingParams

View File

@@ -158,7 +158,7 @@ class LlavaLlamaForCausalLM(nn.Module):
                             num_patch_height, num_patch_width, height, width, -1
                         )
                     else:
-                        raise NotImplementedError
+                        raise NotImplementedError()
                     if "unpad" in self.mm_patch_merge_type:
                         image_feature = image_feature.permute(
                             4, 0, 2, 1, 3

View File

@@ -19,7 +19,7 @@ class ServerArgs:
     schedule_heuristic: str = "lpm"
     schedule_conservativeness: float = 1.0
     random_seed: int = 42
-    stream_interval: int = 2
+    stream_interval: int = 8
     disable_log_stats: bool = False
     log_stats_interval: int = 10
     log_level: str = "info"
@@ -132,7 +132,7 @@ class ServerArgs:
"--stream-interval", "--stream-interval",
type=int, type=int,
default=ServerArgs.stream_interval, default=ServerArgs.stream_interval,
help="The interval in terms of token length for streaming", help="The interval (or buffer size) for streaming in terms of the token length. A smaller value makes streaming smoother, while a larger value makes the throughput higher",
) )
parser.add_argument( parser.add_argument(
"--log-level", "--log-level",

View File

@@ -28,7 +28,7 @@ def test_generate_worker(model_path, tp_rank, tp_size):
     reqs = []
     for i in range(len(prompts)):
-        req = Req(i)
+        req = Req(i, None, None)
         req.input_ids = tokenizer.encode(prompts[i])[:cut_num]
         req.sampling_params = sampling_params
         reqs.append(req)

View File

@@ -112,6 +112,7 @@ def test_generate_worker(
     prefill_params = (
         torch.tensor(np.array(input_ids)).cuda(),
         np.array(pixel_values),
+        [None],
         [offset],
         *params,
     )

View File

@@ -1,5 +1,8 @@
""" """
Usage:
python3 -m sglang.launch_server --model-path TinyLlama/TinyLlama-1.1B-Chat-v0.4 --port 30000 python3 -m sglang.launch_server --model-path TinyLlama/TinyLlama-1.1B-Chat-v0.4 --port 30000
python3 test_httpserver_decode.py
Output: Output:
The capital of France is Paris.\nThe capital of the United States is Washington, D.C.\nThe capital of Canada is Ottawa.\nThe capital of Japan is Tokyo The capital of France is Paris.\nThe capital of the United States is Washington, D.C.\nThe capital of Canada is Ottawa.\nThe capital of Japan is Tokyo

View File

@@ -1,5 +1,7 @@
""" """
Usage:
python3 -m sglang.launch_server --model-path TinyLlama/TinyLlama-1.1B-Chat-v0.4 --port 30000 python3 -m sglang.launch_server --model-path TinyLlama/TinyLlama-1.1B-Chat-v0.4 --port 30000
python3 test_httpserver_decode_stream.py
Output: Output:
The capital of France is Paris.\nThe capital of the United States is Washington, D.C.\nThe capital of Canada is Ottawa.\nThe capital of Japan is Tokyo The capital of France is Paris.\nThe capital of the United States is Washington, D.C.\nThe capital of Canada is Ottawa.\nThe capital of Japan is Tokyo

View File

@@ -1,5 +1,7 @@
""" """
Usage:
python3 -m sglang.launch_server --model-path liuhaotian/llava-v1.5-7b --tokenizer-path llava-hf/llava-1.5-7b-hf --port 30000 python3 -m sglang.launch_server --model-path liuhaotian/llava-v1.5-7b --tokenizer-path llava-hf/llava-1.5-7b-hf --port 30000
python3 test_httpserver_llava.py
Output: Output:
The image features a man standing on the back of a yellow taxi cab, holding The image features a man standing on the back of a yellow taxi cab, holding
@@ -64,9 +66,12 @@ def test_streaming(args):
     )

     prev = 0
-    for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
-        if chunk:
-            data = json.loads(chunk.decode())
+    for chunk in response.iter_lines(decode_unicode=False):
+        chunk = chunk.decode("utf-8")
+        if chunk and chunk.startswith("data:"):
+            if chunk == "data: [DONE]":
+                break
+            data = json.loads(chunk[5:].strip("\n"))
             output = data["text"].strip()
             print(output[prev:], end="", flush=True)
             prev = len(output)
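
For reference, a minimal standalone streaming client following the same chunk format as the updated test above. The endpoint path, payload fields, and server URL are assumptions mirrored from how the test scripts in this repository call the server; only the "data:" / "data: [DONE]" chunk handling is taken directly from this diff:

import json

import requests


def stream_generate(prompt, url="http://127.0.0.1:30000"):
    # Assumed request format: a /generate endpoint with "stream": True,
    # in the style of the httpserver test scripts.
    response = requests.post(
        url + "/generate",
        json={
            "text": prompt,
            "sampling_params": {"temperature": 0, "max_new_tokens": 64},
            "stream": True,
        },
        stream=True,
    )

    # Chunk handling taken from the updated test: each chunk is an SSE-style
    # "data: {...}" line, and the stream ends with "data: [DONE]".
    prev = 0
    for chunk in response.iter_lines(decode_unicode=False):
        chunk = chunk.decode("utf-8")
        if chunk and chunk.startswith("data:"):
            if chunk == "data: [DONE]":
                break
            data = json.loads(chunk[5:].strip("\n"))
            output = data["text"].strip()
            print(output[prev:], end="", flush=True)
            prev = len(output)
    print()


if __name__ == "__main__":
    stream_generate("The capital of France is")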