Improve the structure of CI (#911)
This commit is contained in:
9
scripts/deprecated/test_curl.sh
Normal file
9
scripts/deprecated/test_curl.sh
Normal file
@@ -0,0 +1,9 @@
|
||||
# Smoke-test a locally running sglang server: request a short greedy
# (temperature 0) completion from the /generate endpoint.
curl http://localhost:30000/generate \
  -H "Content-Type: application/json" \
  -d '{
    "text": "Once upon a time,",
    "sampling_params": {
      "max_new_tokens": 16,
      "temperature": 0
    }
  }'
|
||||
215
scripts/deprecated/test_flashinfer.py
Normal file
215
scripts/deprecated/test_flashinfer.py
Normal file
@@ -0,0 +1,215 @@
|
||||
import pytest
|
||||
import torch
|
||||
from flashinfer import (
|
||||
BatchDecodeWithPagedKVCacheWrapper,
|
||||
BatchPrefillWithPagedKVCacheWrapper,
|
||||
)
|
||||
from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
|
||||
|
||||
from sglang.srt.layers.extend_attention import extend_attention_fwd, redundant_attention
|
||||
from sglang.srt.layers.token_attention import token_attention_fwd
|
||||
|
||||
# Module-level wrapper handles. Populated by init_flashinfer() and then
# reused by the test functions below.
flashinfer_prefill_wrapper = None
flashinfer_decode_wrapper = None
|
||||
|
||||
|
||||
@pytest.mark.parametrize("batch_size", [12, 37, 67])
@pytest.mark.parametrize("kv_len", [54, 97])
@pytest.mark.parametrize("qo_len", [37, 17])
@pytest.mark.parametrize("num_kv_heads", [4])
@pytest.mark.parametrize("num_qo_heads", [32, 4])
@pytest.mark.parametrize("head_dim", [128])
def test_batch_prefill_with_paged_kv_cache(
    batch_size,
    kv_len,
    qo_len,
    num_kv_heads,
    num_qo_heads,
    head_dim,
):
    """Cross-check the triton extend-attention kernel against a redundant
    reference implementation and against flashinfer's batch prefill.

    Layout: each request holds ``kv_len`` cached tokens; the last ``qo_len``
    of them form the "extend" (newly prefilled) part, which also supplies
    the queries.
    """
    init_flashinfer(num_qo_heads, num_kv_heads)

    q = torch.randn(batch_size * qo_len, num_qo_heads, head_dim).to(0).half()
    qo_indptr = torch.arange(0, batch_size + 1).to(0).int() * qo_len
    total_tokens = kv_len * batch_size
    # kv_data packs K and V along dim 1: [tokens, 2 (k/v), heads, head_dim].
    kv_data = torch.randn(total_tokens, 2, num_kv_heads, head_dim).to(0).half()
    kv_indptr = torch.arange(0, batch_size + 1).to(0).int() * kv_len
    kv_indices = torch.arange(0, total_tokens).to(0).int()
    # Page size is 1, so every request's last page holds exactly one token.
    kv_last_page_len = torch.full((batch_size,), 1, dtype=torch.int32).to(0)

    # init args for triton kernel
    k_extend = (
        kv_data.view(batch_size, kv_len, 2, -1)[:, -qo_len:, 0]
        .contiguous()
        .view(-1, num_kv_heads, head_dim)
    )
    v_extend = (
        kv_data.view(batch_size, kv_len, 2, -1)[:, -qo_len:, 1]
        .contiguous()
        .view(-1, num_kv_heads, head_dim)
    )
    o_triton = torch.empty_like(q)
    k_buffer = kv_data[:, 0].view(-1, num_kv_heads, head_dim).contiguous()
    v_buffer = kv_data[:, 1].view(-1, num_kv_heads, head_dim).contiguous()
    # Requests occupy contiguous, disjoint token ranges in the buffers.
    req_to_token = torch.arange(0, total_tokens).to(0).int().view(batch_size, kv_len)
    b_req_idx = torch.arange(0, batch_size).to(0).int()
    b_seq_len = torch.full((batch_size,), kv_len, dtype=torch.int32).to(0)
    b_start_loc_extend = torch.arange(0, batch_size).to(0).int() * qo_len
    b_seq_len_extend = torch.full((batch_size,), qo_len, dtype=torch.int32).to(0)
    max_len_in_batch = kv_len
    max_len_extend = qo_len

    extend_attention_fwd(
        q,
        k_extend,
        v_extend,
        o_triton,
        k_buffer,
        v_buffer,
        req_to_token,
        b_req_idx,
        None,  # b_start_loc = None
        b_seq_len,
        None,  # b_seq_len_prefix = None
        b_start_loc_extend,
        b_seq_len_extend,
        max_len_in_batch,
        max_len_extend,
    )

    # Reference path: recompute attention redundantly over the full prefix.
    o_redundant = torch.empty_like(q)
    b_start_loc = torch.zeros((batch_size,), dtype=torch.int32).to(0)
    b_start_loc[1:] = torch.cumsum(b_seq_len[:-1], dim=0)
    b_seq_len_prefix = b_seq_len - b_seq_len_extend

    redundant_attention(
        q,
        k_extend,
        v_extend,
        o_redundant,
        k_buffer,
        v_buffer,
        req_to_token,
        b_req_idx,
        b_start_loc,
        b_seq_len,
        b_seq_len_prefix,
        max_len_in_batch,
    )
    print("Mean: ", torch.mean(torch.abs(o_redundant - o_triton)))
    print("Max: ", torch.max(torch.abs(o_redundant - o_triton)))
    assert torch.allclose(o_redundant, o_triton, rtol=1e-2, atol=1e-3)

    # Flashinfer path: re-plan the prefill wrapper, then run it.
    flashinfer_prefill_wrapper.end_forward()

    flashinfer_prefill_wrapper.begin_forward(
        qo_indptr,
        kv_indptr,
        kv_indices,
        kv_last_page_len,
        num_qo_heads,
        num_kv_heads,
        head_dim,
        1,
    )
    o = flashinfer_prefill_wrapper.forward(
        q.contiguous().view(-1, num_qo_heads, head_dim), kv_data
    )

    print("Mean: ", torch.mean(torch.abs(o - o_triton)))
    print("Max: ", torch.max(torch.abs(o - o_triton)))
    assert torch.allclose(o, o_triton, rtol=1e-2, atol=1e-3)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("batch_size", [12, 17, 37])
@pytest.mark.parametrize("kv_len", [54, 127, 537])
@pytest.mark.parametrize("num_kv_heads", [32])
@pytest.mark.parametrize("num_qo_heads", [32])
@pytest.mark.parametrize("head_dim", [128])
def test_batch_decode_with_paged_kv_cache(
    batch_size,
    kv_len,
    num_kv_heads,
    num_qo_heads,
    head_dim,
):
    """Cross-check the triton token-attention decode kernel against
    flashinfer's batch decode (one new query token per request)."""
    # note(lsyin): when pytest, the number of heads cannot change, because triton kernel has a cache
    # to test different shape of decode, change the parameters in the __main__, and run decode only once
    init_flashinfer(num_qo_heads, num_kv_heads)

    q = torch.randn(batch_size, num_qo_heads, head_dim).to(0).half()
    total_tokens = kv_len * batch_size
    # kv_data packs K and V along dim 1: [tokens, 2 (k/v), heads, head_dim].
    kv_data = torch.randn(total_tokens, 2, num_kv_heads, head_dim).to(0).half()
    kv_indptr = torch.arange(0, batch_size + 1).to(0).int() * kv_len
    kv_indices = torch.arange(0, total_tokens).to(0).int()
    # Page size is 1, so every request's last page holds exactly one token.
    kv_last_page_len = torch.full((batch_size,), 1, dtype=torch.int32).to(0)

    # init args for triton kernel
    k_buffer = kv_data[:, 0].view(-1, num_kv_heads, head_dim).contiguous()
    v_buffer = kv_data[:, 1].view(-1, num_kv_heads, head_dim).contiguous()
    o_triton = torch.empty_like(q)
    # Requests occupy contiguous, disjoint token ranges in the buffers.
    req_to_token = (
        torch.arange(0, kv_len * batch_size).to(0).int().view(batch_size, kv_len)
    )
    b_req_idx = torch.arange(0, batch_size).to(0).int()
    b_start_loc = torch.arange(0, batch_size).to(0).int() * kv_len
    b_seq_len = torch.full((batch_size,), kv_len, dtype=torch.int32).to(0)
    max_len_in_batch = kv_len
    other_kv_index = 0
    token_attention_fwd(
        q,
        k_buffer,
        v_buffer,
        o_triton,
        req_to_token,
        b_req_idx,
        b_start_loc,
        b_seq_len,
        max_len_in_batch,
        other_kv_index,
        total_tokens,
    )

    # Flashinfer path: re-plan the decode wrapper, then run it.
    flashinfer_decode_wrapper.end_forward()
    flashinfer_decode_wrapper.begin_forward(
        kv_indptr,
        kv_indices,
        kv_last_page_len,
        num_qo_heads,
        num_kv_heads,
        head_dim,
        1,
        pos_encoding_mode="NONE",
        data_type="float16",
    )
    o = flashinfer_decode_wrapper.forward(
        q.contiguous().view(-1, num_qo_heads, head_dim), kv_data
    )

    print("Mean: ", torch.mean(torch.abs(o - o_triton)))
    print("Max: ", torch.max(torch.abs(o - o_triton)))
    assert torch.allclose(o, o_triton, rtol=1e-2, atol=2e-3)
|
||||
|
||||
|
||||
def init_flashinfer(num_attention_heads, num_kv_heads):
    """Create the module-global flashinfer prefill/decode wrappers.

    Tensor-core decode is enabled exactly when no grouped-size decode
    kernel was compiled for this (qo, kv) head combination.
    """
    use_tensor_cores = not _grouped_size_compiled_for_decode_kernels(
        num_attention_heads, num_kv_heads
    )

    # 128 MB scratch buffer shared by both wrappers.
    workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8, device="cuda")

    global flashinfer_prefill_wrapper, flashinfer_decode_wrapper

    flashinfer_prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper(
        workspace_buffer, "NHD"
    )
    flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
        workspace_buffer, "NHD", use_tensor_cores=use_tensor_cores
    )
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Script mode allows trying head counts / shapes that pytest cannot
    # (the triton kernels are cached per process; see note in the decode test).
    test_batch_prefill_with_paged_kv_cache(12, 54, 37, 8, 8, 128)
    test_batch_prefill_with_paged_kv_cache(37, 1111, 456, 32, 32, 128)
    test_batch_decode_with_paged_kv_cache(12, 54, 4, 32, 128)
|
||||
69
scripts/deprecated/test_httpserver_classify.py
Normal file
69
scripts/deprecated/test_httpserver_classify.py
Normal file
@@ -0,0 +1,69 @@
|
||||
"""
|
||||
Usage:
|
||||
python3 -m sglang.launch_server --model-path /model/llama-classification
|
||||
|
||||
python3 test_httpserver_classify.py
|
||||
"""
|
||||
|
||||
import argparse
|
||||
|
||||
import numpy as np
|
||||
import requests
|
||||
|
||||
|
||||
def get_logits(url, prompt):
    """Score one prompt: return its normalized prompt logprob.

    Uses max_new_tokens=0 so the server only evaluates the prompt.
    """
    payload = {
        "text": prompt,
        "sampling_params": {
            "max_new_tokens": 0,
        },
        "return_logprob": True,
    }
    resp = requests.post(url + "/generate", json=payload)
    return resp.json()["meta_info"]["normalized_prompt_logprob"]
|
||||
|
||||
|
||||
def get_logits_batch(url, prompts):
    """Score a batch of prompts.

    Returns a numpy array with one normalized prompt logprob per prompt,
    in the same order as ``prompts``.
    """
    response = requests.post(
        url + "/generate",
        json={
            "text": prompts,
            "sampling_params": {
                "max_new_tokens": 0,
            },
            "return_logprob": True,
        },
    )
    ret = response.json()
    # The server returns one result dict per prompt. A list comprehension
    # is clearer (and avoids the list(generator) wrapper) than the old
    # np.array(list(... for ...)) form.
    logits = np.array(
        [ret[i]["meta_info"]["normalized_prompt_logprob"] for i in range(len(prompts))]
    )
    return logits
|
||||
|
||||
|
||||
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="http://127.0.0.1")
    parser.add_argument("--port", type=int, default=30000)
    args = parser.parse_args()

    url = f"{args.host}:{args.port}"

    # A single request
    prompt = "This is a test prompt.<|eot_id|>"
    logits = get_logits(url, prompt)
    print(f"{logits=}")

    # A batch of requests
    prompts = [
        "This is a test prompt.<|eot_id|>",
        "This is another test prompt.<|eot_id|>",
        "This is a long long long long test prompt.<|eot_id|>",
    ]
    logits = get_logits_batch(url, prompts)
    print(f"{logits=}")
|
||||
56
scripts/deprecated/test_httpserver_concurrent.py
Normal file
56
scripts/deprecated/test_httpserver_concurrent.py
Normal file
@@ -0,0 +1,56 @@
|
||||
"""
|
||||
python3 -m sglang.launch_server --model-path TinyLlama/TinyLlama-1.1B-Chat-v0.4 --port 30000
|
||||
|
||||
Output:
|
||||
The capital of France is Paris.\nThe capital of the United States is Washington, D.C.
|
||||
|
||||
The capital of the United Kindom is London.\nThe capital of the United Kingdom is London.\nThe capital of
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import json
|
||||
import time
|
||||
|
||||
import aiohttp
|
||||
import requests
|
||||
|
||||
|
||||
async def send_request(url, data, delay=0):
    """POST ``data`` as JSON to ``url`` after ``delay`` seconds.

    Returns the decoded JSON response.
    """
    await asyncio.sleep(delay)
    async with aiohttp.ClientSession() as session:
        async with session.post(url, json=data) as resp:
            result = await resp.json()
    return result
|
||||
|
||||
|
||||
async def main(args):
    """Fire two /generate requests concurrently (the first delayed by 1s)
    and print both responses."""
    url = f"{args.host}:{args.port}"

    tasks = [
        send_request(
            url + "/generate",
            {
                "text": text,
                "sampling_params": {"temperature": 0, "max_new_tokens": 128},
            },
            delay=delay,
        )
        for text, delay in (
            ("The capital of France is", 1),
            ("The capital of the United Kindom is", 0),
        )
    ]

    rets = await asyncio.gather(*tasks)
    print(rets)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # CLI: host/port of a running sglang server.
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="http://127.0.0.1")
    parser.add_argument("--port", type=int, default=30000)
    args = parser.parse_args()

    asyncio.run(main(args))
|
||||
55
scripts/deprecated/test_httpserver_decode.py
Normal file
55
scripts/deprecated/test_httpserver_decode.py
Normal file
@@ -0,0 +1,55 @@
|
||||
"""
|
||||
Usage:
|
||||
python3 -m sglang.launch_server --model-path TinyLlama/TinyLlama-1.1B-Chat-v0.4 --port 30000
|
||||
python3 test_httpserver_decode.py
|
||||
|
||||
Output:
|
||||
The capital of France is Paris.\nThe capital of the United States is Washington, D.C.\nThe capital of Canada is Ottawa.\nThe capital of Japan is Tokyo
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
|
||||
import requests
|
||||
|
||||
|
||||
def test_decode(url, return_logprob=False, top_logprobs_num=0, return_text=False, n=1):
    """Send one /generate request and dump the raw JSON response.

    Greedy decoding for a single sample; a mild temperature when n > 1
    so the samples can differ.
    """
    payload = {
        "text": "The capital of France is",
        "sampling_params": {
            "temperature": 0 if n == 1 else 0.5,
            "max_new_tokens": 32,
            "n": n,
        },
        "stream": False,
        "return_logprob": return_logprob,
        "top_logprobs_num": top_logprobs_num,
        "return_text_in_logprobs": return_text,
        "logprob_start_len": 0,
    }
    response = requests.post(url + "/generate", json=payload)
    print(json.dumps(response.json()))
    print("=" * 100)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="http://127.0.0.1")
    parser.add_argument("--port", type=int, default=30000)
    args = parser.parse_args()

    url = f"{args.host}:{args.port}"

    # Plain decoding, then multi-sample decoding.
    test_decode(url)
    test_decode(url, n=3)

    # Sweep the logprob-related options.
    for top_logprobs_num in (0, 3):
        for return_text in (True, False):
            test_decode(
                url,
                return_logprob=True,
                top_logprobs_num=top_logprobs_num,
                return_text=return_text,
            )
|
||||
69
scripts/deprecated/test_httpserver_decode_stream.py
Normal file
69
scripts/deprecated/test_httpserver_decode_stream.py
Normal file
@@ -0,0 +1,69 @@
|
||||
"""
|
||||
Usage:
|
||||
python3 -m sglang.launch_server --model-path TinyLlama/TinyLlama-1.1B-Chat-v0.4 --port 30000
|
||||
python3 test_httpserver_decode_stream.py
|
||||
|
||||
Output:
|
||||
The capital of France is Paris.\nThe capital of the United States is Washington, D.C.\nThe capital of Canada is Ottawa.\nThe capital of Japan is Tokyo
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
|
||||
import requests
|
||||
|
||||
|
||||
def test_decode_stream(url, return_logprob, top_logprobs_num):
    """Stream a /generate request and print tokens (or logprobs) as they arrive.

    Parses the SSE-style ``data: ...`` lines the server emits; each payload
    carries the cumulative text / logprob lists so far.
    """
    response = requests.post(
        url + "/generate",
        json={
            "text": "The capital of France is",
            "sampling_params": {
                "temperature": 0,
                "max_new_tokens": 128,
            },
            "stream": True,
            "return_logprob": return_logprob,
            "top_logprobs_num": top_logprobs_num,
            "return_text_in_logprobs": True,
            "logprob_start_len": 0,
        },
        stream=True,
    )

    # `prev` tracks how much of the cumulative output was already printed.
    prev = 0
    for chunk in response.iter_lines(decode_unicode=False):
        chunk = chunk.decode("utf-8")
        if chunk and chunk.startswith("data:"):
            if chunk == "data: [DONE]":
                break
            data = json.loads(chunk[5:].strip("\n"))

            if return_logprob:
                assert data["meta_info"]["input_token_logprobs"] is not None
                assert data["meta_info"]["output_token_logprobs"] is not None
                assert data["meta_info"]["normalized_prompt_logprob"] is not None
                # Print only the logprob entries that are new in this chunk.
                for logprob, token_id, token_text in data["meta_info"][
                    "output_token_logprobs"
                ][prev:]:
                    print(f"{token_text:12s}\t{logprob}\t{token_id}", flush=True)
                prev = len(data["meta_info"]["output_token_logprobs"])
            else:
                output = data["text"].strip()
                print(output[prev:], end="", flush=True)
                prev = len(output)

    print("=" * 100)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="http://127.0.0.1")
    parser.add_argument("--port", type=int, default=30000)
    args = parser.parse_args()

    url = f"{args.host}:{args.port}"

    # Plain stream, then with logprobs, then with top-3 logprobs.
    for return_logprob, top_k in ((False, 0), (True, 0), (True, 3)):
        test_decode_stream(url, return_logprob, top_k)
|
||||
88
scripts/deprecated/test_httpserver_llava.py
Normal file
88
scripts/deprecated/test_httpserver_llava.py
Normal file
@@ -0,0 +1,88 @@
|
||||
"""
|
||||
Usage:
|
||||
python3 -m sglang.launch_server --model-path liuhaotian/llava-v1.5-7b --tokenizer-path llava-hf/llava-1.5-7b-hf --port 30000
|
||||
python3 test_httpserver_llava.py
|
||||
|
||||
Output:
|
||||
The image features a man standing on the back of a yellow taxi cab, holding
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import json
|
||||
|
||||
import aiohttp
|
||||
import requests
|
||||
|
||||
|
||||
async def send_request(url, data, delay=0):
    """POST ``data`` as JSON to ``url`` after ``delay`` seconds.

    Returns the decoded JSON response.
    """
    await asyncio.sleep(delay)
    async with aiohttp.ClientSession() as session:
        async with session.post(url, json=data) as resp:
            result = await resp.json()
    return result
|
||||
|
||||
|
||||
async def test_concurrent(args):
    """Issue 8 identical image-chat requests concurrently; print each answer."""
    url = f"{args.host}:{args.port}"

    coros = [
        send_request(
            url + "/generate",
            {
                "text": "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. USER: <image>\nDescribe this picture ASSISTANT:",
                "image_data": "example_image.png",
                "sampling_params": {
                    "temperature": 0,
                    "max_new_tokens": 16,
                },
            },
        )
        for _ in range(8)
    ]

    rets = await asyncio.gather(*coros)
    for ret in rets:
        print(ret["text"])
|
||||
|
||||
|
||||
def test_streaming(args):
    """Stream one image-chat /generate request, printing text incrementally.

    Parses the SSE-style ``data: ...`` lines; each payload carries the
    cumulative generated text so far.
    """
    url = f"{args.host}:{args.port}"

    response = requests.post(
        url + "/generate",
        json={
            "text": "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. USER: <image>\nDescribe this picture ASSISTANT:",
            "image_data": "example_image.png",
            "sampling_params": {
                "temperature": 0,
                "max_new_tokens": 128,
            },
            "stream": True,
        },
        stream=True,
    )

    # `prev` tracks how much of the cumulative output was already printed.
    prev = 0
    for chunk in response.iter_lines(decode_unicode=False):
        chunk = chunk.decode("utf-8")
        if chunk and chunk.startswith("data:"):
            if chunk == "data: [DONE]":
                break
            data = json.loads(chunk[5:].strip("\n"))
            output = data["text"].strip()
            print(output[prev:], end="", flush=True)
            prev = len(output)
    print("")
|
||||
|
||||
|
||||
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="http://127.0.0.1")
    parser.add_argument("--port", type=int, default=30000)
    args = parser.parse_args()

    # Concurrent (async) requests first, then a streaming request.
    asyncio.run(test_concurrent(args))
    test_streaming(args)
|
||||
42
scripts/deprecated/test_httpserver_reuse.py
Normal file
42
scripts/deprecated/test_httpserver_reuse.py
Normal file
@@ -0,0 +1,42 @@
|
||||
"""
|
||||
python3 -m sglang.launch_server --model-path TinyLlama/TinyLlama-1.1B-Chat-v0.4 --port 30000
|
||||
|
||||
Output:
|
||||
The capital of France is Paris.\nThe capital of the United States is Washington, D.C.\nThe capital of Canada is Ottawa.\nThe capital of Japan is Tokyo
|
||||
"""
|
||||
|
||||
import argparse
|
||||
|
||||
import requests
|
||||
|
||||
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="http://127.0.0.1")
    parser.add_argument("--port", type=int, default=30000)
    args = parser.parse_args()

    url = f"{args.host}:{args.port}"

    def _generate(text):
        # Shared POST helper: greedy 32-token completion for ``text``.
        # (Was duplicated verbatim for both requests.)
        return requests.post(
            url + "/generate",
            json={
                "text": text,
                "sampling_params": {
                    "temperature": 0,
                    "max_new_tokens": 32,
                },
            },
        ).json()

    # The second prompt extends the first answer, exercising the server's
    # prefix reuse.
    print(_generate("The capital of France is"))
    print(
        _generate(
            "The capital of France is Paris.\nThe capital of the United States is"
        )
    )
|
||||
138
scripts/deprecated/test_jump_forward.py
Normal file
138
scripts/deprecated/test_jump_forward.py
Normal file
@@ -0,0 +1,138 @@
|
||||
import argparse
|
||||
from enum import Enum
|
||||
|
||||
from pydantic import BaseModel, constr
|
||||
|
||||
import sglang as sgl
|
||||
from sglang.srt.constrained import build_regex_from_object
|
||||
from sglang.test.test_utils import (
|
||||
add_common_sglang_args_and_parse,
|
||||
select_sglang_backend,
|
||||
)
|
||||
|
||||
# Regex for a dotted-quad IPv4 address (each octet 0-255).
IP_REGEX = r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)"

# Constrained-generation pattern: two IP addresses and a domain name with
# fixed literal text between them (the fixed parts exercise jump-forward).
ip_jump_forward = (
    r"The google's DNS sever address is "
    + IP_REGEX
    + r" and "
    + IP_REGEX
    + r". "
    + r"The google's website domain name is "
    + r"www\.(\w)+\.(\w)+"
    + r"."
)
|
||||
|
||||
|
||||
# fmt: off
@sgl.function
def regex_gen(s):
    # Ask for the Google DNS IPs; the regex constrains the answer's shape.
    s += "Q: What is the IP address of the Google DNS servers?\n"
    s += "A: " + sgl.gen(
        "answer",
        max_tokens=128,
        temperature=0,
        regex=ip_jump_forward,
    )
# fmt: on
|
||||
|
||||
# Target JSON document expressed as a regex (braces/newlines escaped);
# the fixed keys and punctuation exercise jump-forward decoding.
json_jump_forward = (
    r"""The information about Hogwarts is in the following JSON format\.\n"""
    + r"""\n\{\n"""
    + r""" "name": "[\w\d\s]*",\n"""
    + r""" "country": "[\w\d\s]*",\n"""
    + r""" "latitude": [-+]?[0-9]*\.?[0-9]+,\n"""
    + r""" "population": [-+]?[0-9]+,\n"""
    + r""" "top 3 landmarks": \["[\w\d\s]*", "[\w\d\s]*", "[\w\d\s]*"\],\n"""
    + r"""\}\n"""
)

# fmt: off
@sgl.function
def json_gen(s):
    # Generate the whole JSON document under the regex constraint.
    s += sgl.gen(
        "json",
        max_tokens=128,
        temperature=0,
        regex=json_jump_forward,
    )
# fmt: on
|
||||
|
||||
|
||||
class Weapon(str, Enum):
    """Closed set of weapon choices available to a generated character."""

    sword = "sword"
    axe = "axe"
    mace = "mace"
    spear = "spear"
    bow = "bow"
    crossbow = "crossbow"
|
||||
|
||||
|
||||
class Armor(str, Enum):
    """Closed set of armor choices available to a generated character."""

    leather = "leather"
    chainmail = "chainmail"
    plate = "plate"
|
||||
|
||||
|
||||
class Character(BaseModel):
    # Pydantic schema that build_regex_from_object() turns into a regex
    # for constrained generation (see character_gen below).
    name: constr(max_length=10)  # short name, at most 10 characters
    age: int
    armor: Armor
    weapon: Weapon
    strength: int
|
||||
|
||||
|
||||
@sgl.function
def character_gen(s):
    # Constrain the output to the JSON shape derived from the Character model.
    s += "Give me a character description who is a wizard.\n"
    s += sgl.gen(
        "character",
        max_tokens=128,
        temperature=0,
        regex=build_regex_from_object(Character),
    )
|
||||
|
||||
|
||||
def main(args):
    """Run the three constrained-generation demos on the selected backend."""
    # Select backend
    sgl.set_default_backend(select_sglang_backend(args))

    # Run each demo greedily and print its full program text.
    for title, program in (
        ("IP TEST", regex_gen),
        ("JSON TEST", json_gen),
        ("CHARACTER TEST", character_gen),
    ):
        state = program.run(temperature=0)
        print("=" * 20, title, "=" * 20)
        print(state.text())
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Standard sglang benchmark CLI (backend selection etc.).
    arg_parser = argparse.ArgumentParser()
    main(add_common_sglang_args_and_parse(arg_parser))
|
||||
|
||||
# ==================== IP TEST ====================
|
||||
# Q: What is the IP address of the Google DNS servers?
|
||||
# A: The google's DNS sever address is 8.8.8.8 and 8.8.4.4. The google's website domain name is www.google.com.
|
||||
# ==================== JSON TEST ====================
|
||||
# The information about Hogwarts is in the following JSON format.
|
||||
|
||||
# {
|
||||
# "name": "Hogwarts School of Witchcraft and Wizardry",
|
||||
# "country": "Scotland",
|
||||
# "latitude": 55.566667,
|
||||
# "population": 1000,
|
||||
# "top 3 landmarks": ["Hogwarts Castle", "The Great Hall", "The Forbidden Forest"],
|
||||
# }
|
||||
|
||||
# ==================== CHARACTER TEST ====================
|
||||
# Give me a character description who is a wizard.
|
||||
# { "name" : "Merlin", "age" : 500, "armor" : "chainmail" , "weapon" : "sword" , "strength" : 10 }
|
||||
209
scripts/deprecated/test_openai_server.py
Normal file
209
scripts/deprecated/test_openai_server.py
Normal file
@@ -0,0 +1,209 @@
|
||||
"""
|
||||
First run the following command to launch the server.
|
||||
Note that TinyLlama adopts different chat templates in different versions.
|
||||
For v0.4, the chat template is chatml.
|
||||
|
||||
python3 -m sglang.launch_server --model-path TinyLlama/TinyLlama-1.1B-Chat-v0.4 \
|
||||
--port 30000 --chat-template chatml
|
||||
|
||||
Output example:
|
||||
The capital of France is Paris.
|
||||
The capital of the United States is Washington, D.C.
|
||||
The capital of Canada is Ottawa.
|
||||
The capital of Japan is Tokyo
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
|
||||
import openai
|
||||
|
||||
|
||||
def test_completion(args, echo, logprobs):
    """Exercise /v1/completions and sanity-check echo/logprob fields.

    Args:
        echo: whether the prompt is echoed back in the completion text.
        logprobs: False, True, or an int requesting top-k logprobs.
    """
    client = openai.Client(api_key="EMPTY", base_url=args.base_url)
    response = client.completions.create(
        model="default",
        prompt="The capital of France is",
        temperature=0,
        max_tokens=32,
        echo=echo,
        logprobs=logprobs,
    )
    text = response.choices[0].text
    print(text)  # reuse the local instead of a second attribute chain
    if echo:
        assert text.startswith("The capital of France is")
    if logprobs:
        print(response.choices[0].logprobs.top_logprobs)
        assert response.choices[0].logprobs
        # Use identity comparison with None (PEP 8 / E711), not ==/!=.
        if echo:
            # The echoed first prompt token has no logprob.
            assert response.choices[0].logprobs.token_logprobs[0] is None
        else:
            assert response.choices[0].logprobs.token_logprobs[0] is not None
    assert response.id
    assert response.created
    assert response.usage.prompt_tokens > 0
    assert response.usage.completion_tokens > 0
    assert response.usage.total_tokens > 0
    print("=" * 100)
|
||||
|
||||
|
||||
def test_completion_stream(args, echo, logprobs):
    """Stream /v1/completions chunks; print text (plus logprobs if requested)."""
    client = openai.Client(api_key="EMPTY", base_url=args.base_url)
    stream = client.completions.create(
        model="default",
        prompt="The capital of France is",
        temperature=0,
        max_tokens=32,
        stream=True,
        echo=echo,
        logprobs=logprobs,
    )
    for index, chunk in enumerate(stream):
        choice = chunk.choices[0]
        if index == 0 and echo:
            # With echo the first chunk starts with the prompt itself.
            assert choice.text.startswith("The capital of France is")
        if logprobs:
            print(f"{choice.text:12s}\t{choice.logprobs.token_logprobs}", flush=True)
            print(choice.logprobs.top_logprobs)
        else:
            print(choice.text, end="", flush=True)
        assert chunk.id
        assert chunk.usage.prompt_tokens > 0
        assert chunk.usage.completion_tokens > 0
        assert chunk.usage.total_tokens > 0
    print("=" * 100)
|
||||
|
||||
|
||||
def test_chat_completion(args):
    """One non-streaming chat completion; sanity-check the response metadata."""
    client = openai.Client(api_key="EMPTY", base_url=args.base_url)
    messages = [
        {"role": "system", "content": "You are a helpful AI assistant"},
        {"role": "user", "content": "What is the capital of France?"},
    ]
    response = client.chat.completions.create(
        model="default",
        messages=messages,
        temperature=0,
        max_tokens=32,
    )
    print(response.choices[0].message.content)
    assert response.id
    assert response.created
    usage = response.usage
    assert usage.prompt_tokens > 0
    assert usage.completion_tokens > 0
    assert usage.total_tokens > 0
    print("=" * 100)
|
||||
|
||||
|
||||
def test_chat_completion_image(args):
    """Chat completion with a text + image_url user message; checks metadata."""
    client = openai.Client(api_key="EMPTY", base_url=args.base_url)
    user_content = [
        {"type": "text", "text": "Describe this image"},
        {
            "type": "image_url",
            "image_url": {
                "url": "https://raw.githubusercontent.com/sgl-project/sglang/main/assets/mixtral_8x7b.jpg"
            },
        },
    ]
    response = client.chat.completions.create(
        model="default",
        messages=[
            {"role": "system", "content": "You are a helpful AI assistant"},
            {"role": "user", "content": user_content},
        ],
        temperature=0,
        max_tokens=32,
    )
    print(response.choices[0].message.content)
    assert response.id
    assert response.created
    usage = response.usage
    assert usage.prompt_tokens > 0
    assert usage.completion_tokens > 0
    assert usage.total_tokens > 0
    print("=" * 100)
|
||||
|
||||
|
||||
def test_chat_completion_stream(args):
    """Stream a chat completion: verify the role chunk, then print content deltas."""
    client = openai.Client(api_key="EMPTY", base_url=args.base_url)
    stream = client.chat.completions.create(
        model="default",
        messages=[
            {"role": "system", "content": "You are a helpful AI assistant"},
            {"role": "user", "content": "List 3 countries and their capitals."},
        ],
        temperature=0,
        max_tokens=64,
        stream=True,
    )
    for index, chunk in enumerate(stream):
        delta = chunk.choices[0].delta
        if index == 0:
            # The first chunk only carries the assistant role, no content.
            assert delta.role == "assistant"
            continue
        if delta.content:
            print(delta.content, end="", flush=True)
    print("=" * 100)
|
||||
|
||||
|
||||
def test_regex(args):
    """Constrained generation via the sglang-specific `regex` extra_body field."""
    client = openai.Client(api_key="EMPTY", base_url=args.base_url)

    # Target shape: a JSON object with a word "name" and integer "population".
    pattern = (
        r"""\{\n"""
        + r""" "name": "[\w]+",\n"""
        + r""" "population": [\d]+\n"""
        + r"""\}"""
    )

    response = client.chat.completions.create(
        model="default",
        messages=[
            {"role": "system", "content": "You are a helpful AI assistant"},
            {"role": "user", "content": "Introduce the capital of France."},
        ],
        temperature=0,
        max_tokens=128,
        extra_body={"regex": pattern},
    )
    # The constrained output must parse as JSON.
    print(json.loads(response.choices[0].message.content))
    print("=" * 100)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--base-url", type=str, default="http://127.0.0.1:30000/v1")
    parser.add_argument(
        "--test-image", action="store_true", help="Enables testing image inputs"
    )
    args = parser.parse_args()

    # Sweep echo x logprobs for plain completions, then for streaming ones
    # (same call order as listing every combination explicitly).
    for lp in (False, True, 3):
        for use_echo in (False, True):
            test_completion(args, echo=use_echo, logprobs=lp)
    for lp in (False, True, 3):
        for use_echo in (False, True):
            test_completion_stream(args, echo=use_echo, logprobs=lp)
    test_chat_completion(args)
    test_chat_completion_stream(args)
    test_regex(args)
    if args.test_image:
        test_chat_completion_image(args)
|
||||
132
scripts/deprecated/test_robust.py
Normal file
132
scripts/deprecated/test_robust.py
Normal file
@@ -0,0 +1,132 @@
|
||||
import argparse
|
||||
import random
|
||||
import string
|
||||
|
||||
from vllm.transformers_utils.tokenizer import get_tokenizer
|
||||
|
||||
import sglang as sgl
|
||||
from sglang.test.test_utils import (
|
||||
add_common_sglang_args_and_parse,
|
||||
select_sglang_backend,
|
||||
)
|
||||
|
||||
TOKENIZER = None
|
||||
RANDOM_PREFILL_LEN = None
|
||||
RANDOM_DECODE_LEN = None
|
||||
|
||||
|
||||
def gen_prompt(token_num):
    """Return a random alphanumeric string tokenizing to >= token_num tokens."""
    if RANDOM_PREFILL_LEN:
        token_num = random.randint(1, token_num)

    alphabet = string.ascii_letters + string.digits
    prompt = "".join(random.choices(alphabet, k=token_num))
    # Pad one character at a time until the tokenized length reaches the target.
    while len(TOKENIZER(prompt).input_ids) < token_num:
        prompt += random.choice(alphabet)

    return prompt
|
||||
|
||||
|
||||
def robust_test_dfs(s, d, args, leaf_states):
    """Depth-first fork/generate walk of depth d; appends terminal states to leaf_states."""
    if d == 0:
        # Terminal: tag the state so main() can verify nothing was truncated.
        s += "END"
        leaf_states.append(s)
        return

    s += gen_prompt(args.len_prefill)
    children = s.fork(args.num_fork)
    for child in children:
        child += gen_prompt(args.len_prefill)
        if RANDOM_DECODE_LEN:
            decode_len = random.randint(1, args.len_decode)
        else:
            decode_len = args.len_decode
        child += sgl.gen(
            max_tokens=decode_len,
            ignore_eos=True,
        )

    for child in children:
        robust_test_dfs(child, d - 1, args, leaf_states)
|
||||
|
||||
|
||||
def robust_test_bfs(s, args, leaf_states):
    """Breadth-first fork/generate walk; appends terminal states to leaf_states."""
    frontier = [s]
    for _ in range(args.depth):
        next_frontier = []
        for state in frontier:
            state += gen_prompt(args.len_prefill)
            children = state.fork(args.num_fork)
            for child in children:
                child += gen_prompt(args.len_prefill)
                if RANDOM_DECODE_LEN:
                    decode_len = random.randint(1, args.len_decode)
                else:
                    decode_len = args.len_decode
                child += sgl.gen(
                    max_tokens=decode_len,
                    ignore_eos=True,
                )
            next_frontier.extend(children)
        frontier = next_frontier

    # Tag every surviving state so main() can verify nothing was truncated.
    for state in frontier:
        state += "END"
        leaf_states.append(state)
|
||||
|
||||
|
||||
@sgl.function
def robust_test(s, args):
    """Entry point: run the BFS or DFS stress walk and return all leaf states."""
    collected = []
    if args.mode == "bfs":
        robust_test_bfs(s, args, collected)
    else:
        robust_test_dfs(s, args.depth, args, collected)
    return collected
|
||||
|
||||
|
||||
def main(args):
    """Run args.num_req robust_test programs in parallel and dump leaf outputs.

    Every leaf state produced by the walk ends with the "END" sentinel; a
    missing sentinel indicates the runtime corrupted or truncated the output.
    Results are written to tmp_robust_<mode>.txt, one leaf per line.
    """
    backend = select_sglang_backend(args)

    arguments = [{"args": args} for _ in range(args.num_req)]

    states = robust_test.run_batch(
        arguments, temperature=0, backend=backend, num_threads=args.parallel
    )

    with open(f"tmp_robust_{args.mode}.txt", "w") as f:
        for state in states:
            for leaf_state in state.ret_value:
                # Hoist the .text() call — the original evaluated it twice
                # per leaf (once for the check, once for the write).
                text = leaf_state.text()
                # Raise instead of a bare assert so the sentinel check still
                # runs under `python -O` (asserts are stripped there).
                if not text.endswith("END"):
                    raise ValueError(
                        f"leaf state missing END sentinel: {text[-20:]!r}"
                    )
                f.write(text[:-3] + "\n")
|
||||
|
||||
|
||||
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Workload shape: number of programs, tree depth, and fan-out per level.
    parser.add_argument("--num-req", type=int, default=2)
    parser.add_argument("--depth", type=int, default=3)
    parser.add_argument("--num-fork", type=int, default=2)
    # Token budgets for each prefill/decode step.
    parser.add_argument("--len-prefill", type=int, default=128)
    parser.add_argument("--len-decode", type=int, default=128)
    parser.add_argument("--random-prefill-len", action="store_true")
    parser.add_argument("--random-decode-len", action="store_true")
    parser.add_argument("--mode", type=str, default="bfs", choices=["dfs", "bfs"])
    parser.add_argument(
        "--tokenizer", type=str, default="meta-llama/Llama-2-7b-chat-hf"
    )
    parser.add_argument("--trust-remote-code", action="store_true")
    parser.add_argument("--seed", type=int, default=42)
    args = add_common_sglang_args_and_parse(parser)

    # Publish run-wide settings through the module-level globals that
    # gen_prompt() and the walk helpers read.
    RANDOM_PREFILL_LEN = args.random_prefill_len
    RANDOM_DECODE_LEN = args.random_decode_len
    TOKENIZER = get_tokenizer(args.tokenizer, trust_remote_code=args.trust_remote_code)

    random.seed(args.seed)

    main(args)
|
||||
@@ -1,8 +0,0 @@
|
||||
# Run import sorting and formatting over each top-level source directory,
# in the same order as before: isort then black, per directory.
for dir in python test benchmark; do
    isort "$dir"
    black "$dir"
done
|
||||
@@ -1,6 +0,0 @@
|
||||
# Launch a text-generation-inference (TGI) v1.3.0 server for Llama-2-7b-chat.
# The weights directory is bind-mounted read-from at /Llama-2-7b-chat-hf, and
# --network host exposes the server directly on the host at --port 24000.
docker run --name tgi --rm -ti --gpus all --network host \
    -v /home/ubuntu/model_weights/Llama-2-7b-chat-hf:/Llama-2-7b-chat-hf \
    ghcr.io/huggingface/text-generation-inference:1.3.0 \
    --model-id /Llama-2-7b-chat-hf --num-shard 1 --trust-remote-code \
    --max-input-length 2048 --max-total-tokens 4096 \
    --port 24000
|
||||
Reference in New Issue
Block a user