Improve streaming control and reduce first-token latency in streaming (#117)

This commit is contained in:
Lianmin Zheng
2024-01-29 17:05:42 -08:00
committed by GitHub
parent cd6872334e
commit 6f560c761b
12 changed files with 46 additions and 23 deletions

View File

@@ -21,14 +21,17 @@ class FinishReason(Enum):
class Req:
def __init__(self, rid):
def __init__(self, rid, input_text, input_ids):
self.rid = rid
self.input_text = None
self.input_ids = []
self.input_text = input_text
self.input_ids = input_ids
self.output_ids = []
# For vision input
self.pixel_values = None
self.image_size = None
self.image_offset = 0
self.sampling_params = None
self.return_logprob = False
self.logprob_start_len = 0
@@ -46,7 +49,7 @@ class Req:
self.logprob = None
self.normalized_logprob = None
# for constrained decoding
# For constrained decoding
self.regex_fsm = None
self.regex_fsm_state = 0
self.fast_forward_map = None

View File

@@ -40,7 +40,7 @@ class RouterManager:
for obj in out_pyobjs:
self.send_to_detokenizer.send_pyobj(obj)
# async sleep for recving the subsequent request, and avoiding cache miss
# async sleep for receiving the subsequent request and avoiding cache miss
if len(out_pyobjs) != 0:
has_finished = any([obj.finished for obj in out_pyobjs])
if has_finished:

View File

@@ -17,8 +17,8 @@ from sglang.srt.constrained.fsm_cache import FSMCache
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
from sglang.srt.managers.io_struct import (
BatchTokenIDOut,
TokenizedGenerateReqInput,
FlushCacheReq,
TokenizedGenerateReqInput,
)
from sglang.srt.managers.router.infer_batch import Batch, ForwardMode, Req
from sglang.srt.managers.router.model_runner import ModelRunner
@@ -194,6 +194,9 @@ class ModelRpcServer(rpyc.Service):
if self.running_batch.is_empty():
self.running_batch = None
break
if self.out_pyobjs and self.running_batch.reqs[0].stream:
break
else:
# check the available size
available_size = (
@@ -208,8 +211,7 @@ class ModelRpcServer(rpyc.Service):
)
if self.running_batch is not None and self.tp_rank == 0:
if self.decode_forward_ct >= 20:
self.decode_forward_ct = 0
if self.decode_forward_ct % 20 == 0:
num_used = self.max_total_num_token - (
self.token_to_kv_pool.available_size()
+ self.tree_cache.evictable_size()
@@ -225,11 +227,8 @@ class ModelRpcServer(rpyc.Service):
self,
recv_req: TokenizedGenerateReqInput,
):
req = Req(recv_req.rid)
req.input_text = recv_req.input_text
req.input_ids = recv_req.input_ids
req = Req(recv_req.rid, recv_req.input_text, recv_req.input_ids)
req.pixel_values = recv_req.pixel_values
req.image_size = recv_req.image_size
if req.pixel_values is not None:
pad_value = [
(recv_req.image_hash) % self.model_config.vocab_size,
@@ -240,6 +239,7 @@ class ModelRpcServer(rpyc.Service):
req.input_ids, req.image_offset = self.model_runner.model.pad_input_ids(
req.input_ids, pad_value, req.pixel_values.shape, req.image_size
)
req.image_size = recv_req.image_size
req.sampling_params = recv_req.sampling_params
req.return_logprob = recv_req.return_logprob
req.logprob_start_len = recv_req.logprob_start_len
@@ -327,9 +327,11 @@ class ModelRpcServer(rpyc.Service):
req.extend_input_len + req.max_new_tokens() + new_batch_total_tokens
< available_size
):
# Undo the insertion
delta = self.tree_cache.dec_ref_counter(req.last_node)
available_size += delta
else:
# Add this request to the running batch
self.token_to_kv_pool.add_refs(req.prefix_indices)
can_run_list.append(req)
new_batch_total_tokens += (
@@ -421,7 +423,7 @@ class ModelRpcServer(rpyc.Service):
return
# Update batch tensors
self.decode_forward_ct += 1
self.decode_forward_ct = (self.decode_forward_ct + 1) % (1 << 30)
batch.prepare_for_decode()
# Forward
@@ -454,7 +456,13 @@ class ModelRpcServer(rpyc.Service):
unfinished_indices.append(i)
if req.finished or (
req.stream and self.decode_forward_ct % self.stream_interval == 0
(
req.stream
and (
self.decode_forward_ct % self.stream_interval == 0
or len(req.output_ids) == 1
)
)
):
output_rids.append(req.rid)
output_tokens.append(req.output_ids)

View File

@@ -7,7 +7,6 @@ from typing import List
import numpy as np
import torch
import sglang
from sglang.srt.managers.router.infer_batch import Batch, ForwardMode
from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
from sglang.srt.utils import is_multimodal_model
@@ -16,6 +15,8 @@ from vllm.model_executor.layers.quantization.awq import AWQConfig
from vllm.model_executor.model_loader import _set_default_torch_dtype
from vllm.model_executor.parallel_utils.parallel_state import initialize_model_parallel
import sglang
logger = logging.getLogger("model_runner")

View File

@@ -18,9 +18,9 @@ from sglang.srt.hf_transformers_utils import (
)
from sglang.srt.managers.io_struct import (
BatchStrOut,
FlushCacheReq,
GenerateReqInput,
TokenizedGenerateReqInput,
FlushCacheReq,
)
from sglang.srt.mm_utils import expand2square, process_anyres_image
from sglang.srt.sampling_params import SamplingParams

View File

@@ -158,7 +158,7 @@ class LlavaLlamaForCausalLM(nn.Module):
num_patch_height, num_patch_width, height, width, -1
)
else:
raise NotImplementedError
raise NotImplementedError()
if "unpad" in self.mm_patch_merge_type:
image_feature = image_feature.permute(
4, 0, 2, 1, 3

View File

@@ -19,7 +19,7 @@ class ServerArgs:
schedule_heuristic: str = "lpm"
schedule_conservativeness: float = 1.0
random_seed: int = 42
stream_interval: int = 2
stream_interval: int = 8
disable_log_stats: bool = False
log_stats_interval: int = 10
log_level: str = "info"
@@ -132,7 +132,7 @@ class ServerArgs:
"--stream-interval",
type=int,
default=ServerArgs.stream_interval,
help="The interval in terms of token length for streaming",
help="The interval (or buffer size) for streaming in terms of the token length. A smaller value makes streaming smoother, while a larger value makes the throughput higher",
)
parser.add_argument(
"--log-level",