Improve streaming control and reduce first-token latency in streaming (#117)

This commit is contained in:
Lianmin Zheng
2024-01-29 17:05:42 -08:00
committed by GitHub
parent cd6872334e
commit 6f560c761b
12 changed files with 46 additions and 23 deletions

View File

@@ -28,7 +28,7 @@ def test_generate_worker(model_path, tp_rank, tp_size):
reqs = []
for i in range(len(prompts)):
req = Req(i)
req = Req(i, None, None)
req.input_ids = tokenizer.encode(prompts[i])[:cut_num]
req.sampling_params = sampling_params
reqs.append(req)

View File

@@ -112,6 +112,7 @@ def test_generate_worker(
prefill_params = (
torch.tensor(np.array(input_ids)).cuda(),
np.array(pixel_values),
[None],
[offset],
*params,
)

View File

@@ -1,5 +1,8 @@
"""
Usage:
python3 -m sglang.launch_server --model-path TinyLlama/TinyLlama-1.1B-Chat-v0.4 --port 30000
python3 test_httpserver_decode.py
Output:
The capital of France is Paris.\nThe capital of the United States is Washington, D.C.\nThe capital of Canada is Ottawa.\nThe capital of Japan is Tokyo

View File

@@ -1,5 +1,7 @@
"""
Usage:
python3 -m sglang.launch_server --model-path TinyLlama/TinyLlama-1.1B-Chat-v0.4 --port 30000
python3 test_httpserver_decode_stream.py
Output:
The capital of France is Paris.\nThe capital of the United States is Washington, D.C.\nThe capital of Canada is Ottawa.\nThe capital of Japan is Tokyo

View File

@@ -1,5 +1,7 @@
"""
Usage:
python3 -m sglang.launch_server --model-path liuhaotian/llava-v1.5-7b --tokenizer-path llava-hf/llava-1.5-7b-hf --port 30000
python3 test_httpserver_llava.py
Output:
The image features a man standing on the back of a yellow taxi cab, holding
@@ -64,9 +66,12 @@ def test_streaming(args):
)
prev = 0
for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
if chunk:
data = json.loads(chunk.decode())
for chunk in response.iter_lines(decode_unicode=False):
chunk = chunk.decode("utf-8")
if chunk and chunk.startswith("data:"):
if chunk == "data: [DONE]":
break
data = json.loads(chunk[5:].strip("\n"))
output = data["text"].strip()
print(output[prev:], end="", flush=True)
prev = len(output)