adapt to sglang v0.5.2rc1 on dcu
This commit is contained in:
38
scripts/deprecated/convert_yi_vl.py
Normal file
38
scripts/deprecated/convert_yi_vl.py
Normal file
@@ -0,0 +1,38 @@
|
||||
"""
|
||||
Convert Yi-VL config into a format usable with SGLang
|
||||
|
||||
Usage: python3 scripts/convert_yi_vl.py --model-path <path-to-model>
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
|
||||
from transformers import AutoConfig, AutoTokenizer
|
||||
|
||||
|
||||
def add_image_token(model_path: str):
    """Register the ``<image_placeholder>`` special token and save the tokenizer in place."""
    tok = AutoTokenizer.from_pretrained(model_path)
    tok.add_tokens(["<image_placeholder>"], special_tokens=True)

    print(tok)
    tok.save_pretrained(model_path)
|
||||
|
||||
|
||||
def edit_model_config(model_path):
    """Rewrite the HF config so SGLang loads this checkpoint as a Yi-VL model."""
    cfg = AutoConfig.from_pretrained(model_path)

    # Point the config at SGLang's Yi-VL implementation; 64002 is the id of the
    # image placeholder token added by add_image_token().
    cfg.architectures = ["YiVLForCausalLM"]
    cfg.image_token_index = 64002

    print(cfg)
    cfg.save_pretrained(model_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # CLI entry point: patch both the tokenizer and the config of one checkpoint.
    cli = argparse.ArgumentParser()
    cli.add_argument("--model-path", type=str)
    parsed = cli.parse_args()

    add_image_token(parsed.model_path)
    edit_model_config(parsed.model_path)
|
||||
13
scripts/deprecated/convert_yi_vl.sh
Normal file
13
scripts/deprecated/convert_yi_vl.sh
Normal file
@@ -0,0 +1,13 @@
|
||||
# Download the Yi-VL checkpoints and patch them for SGLang.
# Fix: use `mkdir -p` — the plain `mkdir` failed on reruns and the second
# invocation always failed because the directory already existed; also abort
# if `cd` fails so later commands don't run in the wrong directory.

# For 34B Model
mkdir -p ~/model_weights
cd ~/model_weights || exit 1
git clone https://huggingface.co/01-ai/Yi-VL-34B
# The ViT preprocessor config must sit at the model root for SGLang to find it.
cp ~/model_weights/Yi-VL-34B/vit/clip-vit-H-14-laion2B-s32B-b79K-yi-vl-34B-448/preprocessor_config.json ~/model_weights/Yi-VL-34B
python3 convert_yi_vl.py --model-path ~/model_weights/Yi-VL-34B

# For 6B Model
mkdir -p ~/model_weights
cd ~/model_weights || exit 1
git clone https://huggingface.co/01-ai/Yi-VL-6B
cp ~/model_weights/Yi-VL-6B/vit/clip-vit-H-14-laion2B-s32B-b79K-yi-vl-6B-448/preprocessor_config.json ~/model_weights/Yi-VL-6B
python3 convert_yi_vl.py --model-path ~/model_weights/Yi-VL-6B
|
||||
9
scripts/deprecated/test_curl.sh
Normal file
9
scripts/deprecated/test_curl.sh
Normal file
@@ -0,0 +1,9 @@
|
||||
# Smoke-test the /generate endpoint of a locally running SGLang server
# with a single greedy (temperature 0) completion request.
curl http://localhost:30000/generate \
    --header "Content-Type: application/json" \
    --data '{
    "text": "Once upon a time,",
    "sampling_params": {
      "max_new_tokens": 64,
      "temperature": 0
    }
  }'
|
||||
217
scripts/deprecated/test_flashinfer.py
Normal file
217
scripts/deprecated/test_flashinfer.py
Normal file
@@ -0,0 +1,217 @@
|
||||
import pytest
|
||||
import torch
|
||||
from flashinfer import (
|
||||
BatchDecodeWithPagedKVCacheWrapper,
|
||||
BatchPrefillWithPagedKVCacheWrapper,
|
||||
)
|
||||
|
||||
from sglang.srt.layers.attention.triton_ops.decode_attention import decode_attention_fwd
|
||||
from sglang.srt.layers.attention.triton_ops.extend_attention import (
|
||||
extend_attention_fwd,
|
||||
redundant_attention,
|
||||
)
|
||||
from sglang.srt.utils import should_use_tensor_core
|
||||
|
||||
# Lazily-initialized FlashInfer wrapper singletons; populated by
# init_flashinfer() and shared by the prefill/decode tests below.
flashinfer_prefill_wrapper = None
flashinfer_decode_wrapper = None
|
||||
|
||||
|
||||
@pytest.mark.parametrize("batch_size", [12, 37, 67])
@pytest.mark.parametrize("kv_len", [54, 97])
@pytest.mark.parametrize("qo_len", [37, 17])
@pytest.mark.parametrize("num_kv_heads", [4])
@pytest.mark.parametrize("num_qo_heads", [32, 4])
@pytest.mark.parametrize("head_dim", [128])
def test_batch_prefill_with_paged_kv_cache(
    batch_size,
    kv_len,
    qo_len,
    num_kv_heads,
    num_qo_heads,
    head_dim,
):
    """Cross-check the triton extend kernel against a redundant reference
    implementation and against FlashInfer's batch prefill."""
    init_flashinfer(num_qo_heads, num_kv_heads)

    # FlashInfer inputs: ragged q plus a paged KV cache (page size 1, so
    # kv_last_page_len is all ones and kv_indices enumerates every token).
    q = torch.randn(batch_size * qo_len, num_qo_heads, head_dim).to(0).half()
    qo_indptr = torch.arange(0, batch_size + 1).to(0).int() * qo_len
    total_tokens = kv_len * batch_size
    kv_data = torch.randn(total_tokens, 2, num_kv_heads, head_dim).to(0).half()
    kv_indptr = torch.arange(0, batch_size + 1).to(0).int() * kv_len
    kv_indices = torch.arange(0, total_tokens).to(0).int()
    kv_last_page_len = torch.full((batch_size,), 1, dtype=torch.int32).to(0)

    # init args for triton kernel
    # The last qo_len tokens of each sequence form the "extend" slice;
    # index 0 / 1 of the second kv_data axis selects K / V.
    k_extend = (
        kv_data.view(batch_size, kv_len, 2, -1)[:, -qo_len:, 0]
        .contiguous()
        .view(-1, num_kv_heads, head_dim)
    )
    v_extend = (
        kv_data.view(batch_size, kv_len, 2, -1)[:, -qo_len:, 1]
        .contiguous()
        .view(-1, num_kv_heads, head_dim)
    )
    o_triton = torch.empty_like(q)
    k_buffer = kv_data[:, 0].view(-1, num_kv_heads, head_dim).contiguous()
    v_buffer = kv_data[:, 1].view(-1, num_kv_heads, head_dim).contiguous()
    req_to_token = torch.arange(0, total_tokens).to(0).int().view(batch_size, kv_len)
    b_req_idx = torch.arange(0, batch_size).to(0).int()
    b_seq_len = torch.full((batch_size,), kv_len, dtype=torch.int32).to(0)
    b_start_loc_extend = torch.arange(0, batch_size).to(0).int() * qo_len
    b_seq_len_extend = torch.full((batch_size,), qo_len, dtype=torch.int32).to(0)
    max_len_in_batch = kv_len
    max_len_extend = qo_len

    extend_attention_fwd(
        q,
        k_extend,
        v_extend,
        o_triton,
        k_buffer,
        v_buffer,
        req_to_token,
        b_req_idx,
        None,  # b_start_loc = None
        b_seq_len,
        None,  # b_seq_len_prefix = None
        b_start_loc_extend,
        b_seq_len_extend,
        max_len_in_batch,
        max_len_extend,
    )

    # Reference: straightforward (redundant) attention over the same tensors.
    o_redundant = torch.empty_like(q)
    b_start_loc = torch.zeros((batch_size,), dtype=torch.int32).to(0)
    b_start_loc[1:] = torch.cumsum(b_seq_len[:-1], dim=0)
    b_seq_len_prefix = b_seq_len - b_seq_len_extend

    redundant_attention(
        q,
        k_extend,
        v_extend,
        o_redundant,
        k_buffer,
        v_buffer,
        req_to_token,
        b_req_idx,
        b_start_loc,
        b_seq_len,
        b_seq_len_prefix,
        max_len_in_batch,
    )
    print("Mean: ", torch.mean(torch.abs(o_redundant - o_triton)))
    print("Max: ", torch.max(torch.abs(o_redundant - o_triton)))
    assert torch.allclose(o_redundant, o_triton, rtol=1e-2, atol=1e-3)

    # Re-plan the (module-level) FlashInfer prefill wrapper for this batch.
    flashinfer_prefill_wrapper.end_forward()

    flashinfer_prefill_wrapper.begin_forward(
        qo_indptr,
        kv_indptr,
        kv_indices,
        kv_last_page_len,
        num_qo_heads,
        num_kv_heads,
        head_dim,
        1,
    )
    o = flashinfer_prefill_wrapper.forward(
        q.contiguous().view(-1, num_qo_heads, head_dim), kv_data
    )

    print("Mean: ", torch.mean(torch.abs(o - o_triton)))
    print("Max: ", torch.max(torch.abs(o - o_triton)))
    assert torch.allclose(o, o_triton, rtol=1e-2, atol=1e-3)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("batch_size", [12, 17, 37])
@pytest.mark.parametrize("kv_len", [54, 127, 537])
@pytest.mark.parametrize("num_kv_heads", [32])
@pytest.mark.parametrize("num_qo_heads", [32])
@pytest.mark.parametrize("head_dim", [128])
def test_batch_decode_with_paged_kv_cache(
    batch_size,
    kv_len,
    num_kv_heads,
    num_qo_heads,
    head_dim,
):
    """Cross-check the triton decode kernel against FlashInfer batch decode."""
    # note(lsyin): when pytest, the number of heads cannot change, because triton kernel has a cache
    # to test different shape of decode, change the parameters in the __main__, and run decode only once
    init_flashinfer(num_qo_heads, num_kv_heads)

    # One query row per sequence (decode step) over a paged KV cache (page size 1).
    q = torch.randn(batch_size, num_qo_heads, head_dim).to(0).half()
    total_tokens = kv_len * batch_size
    kv_data = torch.randn(total_tokens, 2, num_kv_heads, head_dim).to(0).half()
    kv_indptr = torch.arange(0, batch_size + 1).to(0).int() * kv_len
    kv_indices = torch.arange(0, total_tokens).to(0).int()
    kv_last_page_len = torch.full((batch_size,), 1, dtype=torch.int32).to(0)

    # init args for triton kernel
    k_buffer = kv_data[:, 0].view(-1, num_kv_heads, head_dim).contiguous()
    v_buffer = kv_data[:, 1].view(-1, num_kv_heads, head_dim).contiguous()
    o_triton = torch.empty_like(q)
    req_to_token = (
        torch.arange(0, kv_len * batch_size).to(0).int().view(batch_size, kv_len)
    )
    b_req_idx = torch.arange(0, batch_size).to(0).int()
    b_start_loc = torch.arange(0, batch_size).to(0).int() * kv_len
    b_seq_len = torch.full((batch_size,), kv_len, dtype=torch.int32).to(0)
    max_len_in_batch = kv_len
    other_kv_index = 0
    decode_attention_fwd(
        q,
        k_buffer,
        v_buffer,
        o_triton,
        req_to_token,
        b_req_idx,
        b_start_loc,
        b_seq_len,
        max_len_in_batch,
        other_kv_index,
        total_tokens,
    )

    # Re-plan the (module-level) FlashInfer decode wrapper for this batch.
    flashinfer_decode_wrapper.end_forward()
    flashinfer_decode_wrapper.begin_forward(
        kv_indptr,
        kv_indices,
        kv_last_page_len,
        num_qo_heads,
        num_kv_heads,
        head_dim,
        1,
        pos_encoding_mode="NONE",
        data_type="float16",
    )
    o = flashinfer_decode_wrapper.forward(
        q.contiguous().view(-1, num_qo_heads, head_dim), kv_data
    )

    print("Mean: ", torch.mean(torch.abs(o - o_triton)))
    print("Max: ", torch.max(torch.abs(o - o_triton)))
    assert torch.allclose(o, o_triton, rtol=1e-2, atol=2e-3)
|
||||
|
||||
|
||||
def init_flashinfer(num_attention_heads, num_kv_heads):
    """(Re)build the module-level FlashInfer prefill/decode wrappers."""
    global flashinfer_prefill_wrapper, flashinfer_decode_wrapper

    # Tensor-core decode is only enabled for head configurations where it pays off.
    tensor_cores = should_use_tensor_core(
        torch.half, num_attention_heads, num_kv_heads
    )

    # 128 MB scratch buffer shared by both wrappers.
    workspace = torch.empty(128 * 1024 * 1024, dtype=torch.int8, device="cuda")

    flashinfer_prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper(workspace, "NHD")
    flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
        workspace, "NHD", use_tensor_cores=tensor_cores
    )
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Run directly (not via pytest) so different head counts can be exercised;
    # the triton kernel cache prevents changing head counts within one pytest run.
    test_batch_prefill_with_paged_kv_cache(12, 54, 37, 8, 8, 128)
    test_batch_prefill_with_paged_kv_cache(37, 1111, 456, 32, 32, 128)
    test_batch_decode_with_paged_kv_cache(12, 54, 4, 32, 128)
|
||||
56
scripts/deprecated/test_httpserver_concurrent.py
Normal file
56
scripts/deprecated/test_httpserver_concurrent.py
Normal file
@@ -0,0 +1,56 @@
|
||||
"""
|
||||
python3 -m sglang.launch_server --model-path TinyLlama/TinyLlama-1.1B-Chat-v0.4 --port 30000
|
||||
|
||||
Output:
|
||||
The capital of France is Paris.\nThe capital of the United States is Washington, D.C.
|
||||
|
||||
The capital of the United Kindom is London.\nThe capital of the United Kingdom is London.\nThe capital of
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import json
|
||||
import time
|
||||
|
||||
import aiohttp
|
||||
import requests
|
||||
|
||||
|
||||
async def send_request(url, data, delay=0):
    """POST ``data`` as JSON to ``url`` after ``delay`` seconds; return the parsed reply."""
    await asyncio.sleep(delay)
    async with aiohttp.ClientSession() as client:
        response = await client.post(url, json=data)
        async with response:
            return await response.json()
|
||||
|
||||
|
||||
async def main(args):
    """Fire two generation requests concurrently (the first delayed 1s) and print both replies."""
    base = f"{args.host}:{args.port}"
    sampling = {"temperature": 0, "max_new_tokens": 128}

    task1 = send_request(
        base + "/generate",
        {"text": "The capital of France is", "sampling_params": sampling},
        delay=1,
    )
    task2 = send_request(
        base + "/generate",
        {"text": "The capital of the United Kindom is", "sampling_params": sampling},
    )

    print(await asyncio.gather(task1, task2))
|
||||
|
||||
|
||||
if __name__ == "__main__":
    cli = argparse.ArgumentParser()
    cli.add_argument("--host", type=str, default="http://127.0.0.1")
    cli.add_argument("--port", type=int, default=30000)

    asyncio.run(main(cli.parse_args()))
|
||||
55
scripts/deprecated/test_httpserver_decode.py
Normal file
55
scripts/deprecated/test_httpserver_decode.py
Normal file
@@ -0,0 +1,55 @@
|
||||
"""
|
||||
Usage:
|
||||
python3 -m sglang.launch_server --model-path TinyLlama/TinyLlama-1.1B-Chat-v0.4 --port 30000
|
||||
python3 test_httpserver_decode.py
|
||||
|
||||
Output:
|
||||
The capital of France is Paris.\nThe capital of the United States is Washington, D.C.\nThe capital of Canada is Ottawa.\nThe capital of Japan is Tokyo
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
|
||||
import requests
|
||||
|
||||
|
||||
def test_decode(url, return_logprob=False, top_logprobs_num=0, return_text=False, n=1):
    """Send one non-streaming /generate request and dump the JSON reply."""
    # Greedy when a single completion is requested; mildly random for n > 1
    # so the parallel samples differ.
    payload = {
        "text": "The capital of France is",
        "sampling_params": {
            "temperature": 0 if n == 1 else 0.5,
            "max_new_tokens": 32,
            "n": n,
        },
        "stream": False,
        "return_logprob": return_logprob,
        "top_logprobs_num": top_logprobs_num,
        "return_text_in_logprobs": return_text,
        "logprob_start_len": 0,
    }
    reply = requests.post(url + "/generate", json=payload)
    print(json.dumps(reply.json()))
    print("=" * 100)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    cli = argparse.ArgumentParser()
    cli.add_argument("--host", type=str, default="http://127.0.0.1")
    cli.add_argument("--port", type=int, default=30000)
    opts = cli.parse_args()

    base = f"{opts.host}:{opts.port}"

    test_decode(base)
    test_decode(base, n=3)

    # Exercise every combination of top-k logprobs and token-text return.
    for k in [0, 3]:
        for with_text in [True, False]:
            test_decode(
                base,
                return_logprob=True,
                top_logprobs_num=k,
                return_text=with_text,
            )
|
||||
68
scripts/deprecated/test_httpserver_decode_stream.py
Normal file
68
scripts/deprecated/test_httpserver_decode_stream.py
Normal file
@@ -0,0 +1,68 @@
|
||||
"""
|
||||
Usage:
|
||||
python3 -m sglang.launch_server --model-path TinyLlama/TinyLlama-1.1B-Chat-v0.4 --port 30000
|
||||
python3 test_httpserver_decode_stream.py
|
||||
|
||||
Output:
|
||||
The capital of France is Paris.\nThe capital of the United States is Washington, D.C.\nThe capital of Canada is Ottawa.\nThe capital of Japan is Tokyo
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
|
||||
import requests
|
||||
|
||||
|
||||
def test_decode_stream(url, return_logprob, top_logprobs_num):
    """Stream one /generate request and print tokens (or logprob tuples) as they arrive."""
    response = requests.post(
        url + "/generate",
        json={
            "text": "The capital of France is",
            "sampling_params": {
                "temperature": 0,
                "max_new_tokens": 128,
            },
            "stream": True,
            "return_logprob": return_logprob,
            "top_logprobs_num": top_logprobs_num,
            "return_text_in_logprobs": True,
            "logprob_start_len": 0,
        },
        stream=True,
    )

    # `prev` tracks how much of the (cumulative) stream has been printed:
    # a token count in logprob mode, a character count in text mode.
    prev = 0
    for chunk in response.iter_lines(decode_unicode=False):
        chunk = chunk.decode("utf-8")
        if chunk and chunk.startswith("data:"):
            if chunk == "data: [DONE]":
                break
            # Strip the SSE "data:" prefix before parsing the JSON payload.
            data = json.loads(chunk[5:].strip("\n"))

            if return_logprob:
                assert data["meta_info"]["input_token_logprobs"] is not None
                assert data["meta_info"]["output_token_logprobs"] is not None
                # Print only the tokens that arrived since the last chunk.
                for logprob, token_id, token_text in data["meta_info"][
                    "output_token_logprobs"
                ][prev:]:
                    print(f"{token_text:12s}\t{logprob}\t{token_id}", flush=True)
                prev = len(data["meta_info"]["output_token_logprobs"])
            else:
                output = data["text"].strip()
                print(output[prev:], end="", flush=True)
                prev = len(output)

    print("=" * 100)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    cli = argparse.ArgumentParser()
    cli.add_argument("--host", type=str, default="http://127.0.0.1")
    cli.add_argument("--port", type=int, default=30000)
    opts = cli.parse_args()

    base = f"{opts.host}:{opts.port}"

    # Plain text, logprobs without top-k, then logprobs with top-3.
    test_decode_stream(base, False, 0)
    test_decode_stream(base, True, 0)
    test_decode_stream(base, True, 3)
|
||||
88
scripts/deprecated/test_httpserver_llava.py
Normal file
88
scripts/deprecated/test_httpserver_llava.py
Normal file
@@ -0,0 +1,88 @@
|
||||
"""
|
||||
Usage:
|
||||
python3 -m sglang.launch_server --model-path liuhaotian/llava-v1.5-7b --tokenizer-path llava-hf/llava-1.5-7b-hf --port 30000
|
||||
python3 test_httpserver_llava.py
|
||||
|
||||
Output:
|
||||
The image features a man standing on the back of a yellow taxi cab, holding
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import json
|
||||
|
||||
import aiohttp
|
||||
import requests
|
||||
|
||||
|
||||
async def send_request(url, data, delay=0):
    """POST ``data`` as JSON after an optional delay and return the decoded JSON reply."""
    await asyncio.sleep(delay)
    async with aiohttp.ClientSession() as session:
        resp = await session.post(url, json=data)
        async with resp:
            return await resp.json()
|
||||
|
||||
|
||||
async def test_concurrent(args):
    """Issue 8 identical image-description requests concurrently and print each answer."""
    base = f"{args.host}:{args.port}"

    payload = {
        "text": "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. USER: <image>\nDescribe this picture ASSISTANT:",
        "image_data": "example_image.png",
        "sampling_params": {
            "temperature": 0,
            "max_new_tokens": 64,
        },
    }
    pending = [send_request(base + "/generate", payload) for _ in range(8)]

    for ret in await asyncio.gather(*pending):
        print(ret["text"])
|
||||
|
||||
|
||||
def test_streaming(args):
    """Stream one image-description request and print the text incrementally."""
    url = f"{args.host}:{args.port}"

    response = requests.post(
        url + "/generate",
        json={
            "text": "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. USER: <image>\nDescribe this picture ASSISTANT:",
            "image_data": "example_image.png",
            "sampling_params": {
                "temperature": 0,
                "max_new_tokens": 128,
            },
            "stream": True,
        },
        stream=True,
    )

    # `prev` marks how many characters of the cumulative text were printed so far.
    prev = 0
    for chunk in response.iter_lines(decode_unicode=False):
        chunk = chunk.decode("utf-8")
        if chunk and chunk.startswith("data:"):
            if chunk == "data: [DONE]":
                break
            # Strip the SSE "data:" prefix before parsing the JSON payload.
            data = json.loads(chunk[5:].strip("\n"))
            output = data["text"].strip()
            print(output[prev:], end="", flush=True)
            prev = len(output)
    print("")
|
||||
|
||||
|
||||
if __name__ == "__main__":
    cli = argparse.ArgumentParser()
    cli.add_argument("--host", type=str, default="http://127.0.0.1")
    cli.add_argument("--port", type=int, default=30000)
    opts = cli.parse_args()

    # Concurrent batch first, then the streaming variant.
    asyncio.run(test_concurrent(opts))

    test_streaming(opts)
|
||||
42
scripts/deprecated/test_httpserver_reuse.py
Normal file
42
scripts/deprecated/test_httpserver_reuse.py
Normal file
@@ -0,0 +1,42 @@
|
||||
"""
|
||||
python3 -m sglang.launch_server --model-path TinyLlama/TinyLlama-1.1B-Chat-v0.4 --port 30000
|
||||
|
||||
Output:
|
||||
The capital of France is Paris.\nThe capital of the United States is Washington, D.C.\nThe capital of Canada is Ottawa.\nThe capital of Japan is Tokyo
|
||||
"""
|
||||
|
||||
import argparse
|
||||
|
||||
import requests
|
||||
|
||||
if __name__ == "__main__":
    cli = argparse.ArgumentParser()
    cli.add_argument("--host", type=str, default="http://127.0.0.1")
    cli.add_argument("--port", type=int, default=30000)
    opts = cli.parse_args()

    base = f"{opts.host}:{opts.port}"

    # Two sequential requests where the second prompt extends the first answer,
    # letting the server reuse the cached prefix.
    prompts = [
        "The capital of France is",
        "The capital of France is Paris.\nThe capital of the United States is",
    ]
    for prompt in prompts:
        reply = requests.post(
            base + "/generate",
            json={
                "text": prompt,
                "sampling_params": {
                    "temperature": 0,
                    "max_new_tokens": 32,
                },
            },
        )
        print(reply.json())
|
||||
138
scripts/deprecated/test_jump_forward.py
Normal file
138
scripts/deprecated/test_jump_forward.py
Normal file
@@ -0,0 +1,138 @@
|
||||
import argparse
|
||||
from enum import Enum
|
||||
|
||||
from pydantic import BaseModel, constr
|
||||
|
||||
import sglang as sgl
|
||||
from sglang.srt.constrained.outlines_backend import build_regex_from_object
|
||||
from sglang.test.test_utils import (
|
||||
add_common_sglang_args_and_parse,
|
||||
select_sglang_backend,
|
||||
)
|
||||
|
||||
# Dotted-quad IPv4 address (each octet 0-255).
IP_REGEX = r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)"

# Regex-constrained answer: two IP addresses then a www domain name. The long
# fixed literals are what jump-forward decoding can skip over.
ip_jump_forward = (
    r"The google's DNS sever address is "
    + IP_REGEX
    + r" and "
    + IP_REGEX
    + r". "
    + r"The google's website domain name is "
    + r"www\.(\w)+\.(\w)+"
    + r"."
)
|
||||
|
||||
|
||||
# fmt: off
@sgl.function
def regex_gen(s):
    # Ask a question and constrain the answer with the IP/domain regex above.
    s += "Q: What is the IP address of the Google DNS servers?\n"
    s += "A: " + sgl.gen(
        "answer",
        max_tokens=128,
        temperature=0,
        regex=ip_jump_forward,
    )
# fmt: on
|
||||
|
||||
# Regex describing a fixed JSON skeleton: the literal keys and punctuation can
# be jump-forwarded while only the value parts are generated.
json_jump_forward = (
    r"""The information about Hogwarts is in the following JSON format\.\n"""
    + r"""\n\{\n"""
    + r"""    "name": "[\w\d\s]*",\n"""
    + r"""    "country": "[\w\d\s]*",\n"""
    + r"""    "latitude": [-+]?[0-9]*\.?[0-9]+,\n"""
    + r"""    "population": [-+]?[0-9]+,\n"""
    + r"""    "top 3 landmarks": \["[\w\d\s]*", "[\w\d\s]*", "[\w\d\s]*"\],\n"""
    + r"""\}\n"""
)
|
||||
|
||||
# fmt: off
@sgl.function
def json_gen(s):
    # Generate the whole JSON document under the skeleton regex constraint.
    s += sgl.gen(
        "json",
        max_tokens=128,
        temperature=0,
        regex=json_jump_forward,
    )
# fmt: on
|
||||
|
||||
|
||||
class Weapon(str, Enum):
    """Closed set of weapon names a generated character may carry."""

    sword = "sword"
    axe = "axe"
    mace = "mace"
    spear = "spear"
    bow = "bow"
    crossbow = "crossbow"
|
||||
|
||||
|
||||
class Armor(str, Enum):
    """Closed set of armor types a generated character may wear."""

    leather = "leather"
    chainmail = "chainmail"
    plate = "plate"
|
||||
|
||||
|
||||
class Character(BaseModel):
    """Pydantic schema turned into a constrained-generation regex by
    build_regex_from_object() in character_gen below."""

    # Name capped at 10 characters to keep the generated regex small.
    name: constr(max_length=10)
    age: int
    armor: Armor
    weapon: Weapon
    strength: int
|
||||
|
||||
|
||||
@sgl.function
def character_gen(s):
    # Constrain the free-form description to the Character JSON schema.
    s += "Give me a character description who is a wizard.\n"
    s += sgl.gen(
        "character",
        max_tokens=128,
        temperature=0,
        regex=build_regex_from_object(Character),
    )
|
||||
|
||||
|
||||
def main(args):
    """Run the three jump-forward demos (IP, JSON, character) and print each result."""
    # Select backend
    sgl.set_default_backend(select_sglang_backend(args))

    for label, program in [
        ("IP TEST", regex_gen),
        ("JSON TEST", json_gen),
        ("CHARACTER TEST", character_gen),
    ]:
        state = program.run(temperature=0)
        print("=" * 20, label, "=" * 20)
        print(state.text())
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Standard SGLang benchmark CLI (backend selection, parallelism, ...).
    main(add_common_sglang_args_and_parse(argparse.ArgumentParser()))
|
||||
|
||||
# ==================== IP TEST ====================
|
||||
# Q: What is the IP address of the Google DNS servers?
|
||||
# A: The google's DNS sever address is 8.8.8.8 and 8.8.4.4. The google's website domain name is www.google.com.
|
||||
# ==================== JSON TEST ====================
|
||||
# The information about Hogwarts is in the following JSON format.
|
||||
|
||||
# {
|
||||
# "name": "Hogwarts School of Witchcraft and Wizardry",
|
||||
# "country": "Scotland",
|
||||
# "latitude": 55.566667,
|
||||
# "population": 1000,
|
||||
# "top 3 landmarks": ["Hogwarts Castle", "The Great Hall", "The Forbidden Forest"],
|
||||
# }
|
||||
|
||||
# ==================== CHARACTER TEST ====================
|
||||
# Give me a character description who is a wizard.
|
||||
# { "name" : "Merlin", "age" : 500, "armor" : "chainmail" , "weapon" : "sword" , "strength" : 10 }
|
||||
132
scripts/deprecated/test_robust.py
Normal file
132
scripts/deprecated/test_robust.py
Normal file
@@ -0,0 +1,132 @@
|
||||
import argparse
|
||||
import random
|
||||
import string
|
||||
|
||||
from vllm.transformers_utils.tokenizer import get_tokenizer
|
||||
|
||||
import sglang as sgl
|
||||
from sglang.test.test_utils import (
|
||||
add_common_sglang_args_and_parse,
|
||||
select_sglang_backend,
|
||||
)
|
||||
|
||||
# Populated from CLI flags in __main__ before main() runs.
TOKENIZER = None  # HF tokenizer used by gen_prompt() to count prompt tokens
RANDOM_PREFILL_LEN = None  # if truthy, prefill lengths are sampled in [1, len_prefill]
RANDOM_DECODE_LEN = None  # if truthy, decode lengths are sampled in [1, len_decode]
|
||||
|
||||
|
||||
def gen_prompt(token_num):
    """Return a random alphanumeric string that tokenizes to at least ``token_num`` tokens."""
    if RANDOM_PREFILL_LEN:
        token_num = random.randint(1, token_num)

    alphabet = string.ascii_letters + string.digits
    prompt = "".join(random.choices(alphabet, k=token_num))
    # Grow one character at a time until the tokenizer sees enough tokens.
    while len(TOKENIZER(prompt).input_ids) < token_num:
        prompt += random.choice(alphabet)

    return prompt
|
||||
|
||||
|
||||
def robust_test_dfs(s, d, args, leaf_states):
    """Depth-first fork tree: each node appends a random prompt, forks,
    generates on every fork, then recurses until depth ``d`` reaches 0."""
    if d == 0:
        # Leaf: append the END marker that main() later verifies.
        s += "END"
        leaf_states.append(s)
        return

    s += gen_prompt(args.len_prefill)
    forks = s.fork(args.num_fork)
    for fork_s in forks:
        fork_s += gen_prompt(args.len_prefill)
        # Fixed decode length, or a random one when --random-decode-len is set.
        new_tokens = (
            args.len_decode
            if not RANDOM_DECODE_LEN
            else random.randint(1, args.len_decode)
        )
        fork_s += sgl.gen(
            max_tokens=new_tokens,
            ignore_eos=True,
        )

    for fork_s in forks:
        robust_test_dfs(fork_s, d - 1, args, leaf_states)
|
||||
|
||||
|
||||
def robust_test_bfs(s, args, leaf_states):
    """Breadth-first fork tree: expand the whole frontier one level at a time
    for ``args.depth`` levels, then mark every surviving state as a leaf."""
    old_forks = [s]
    new_forks = []
    for _ in range(args.depth):
        for old_fork in old_forks:
            old_fork += gen_prompt(args.len_prefill)
            forks = old_fork.fork(args.num_fork)
            for fork_s in forks:
                fork_s += gen_prompt(args.len_prefill)
                # Fixed decode length, or random when --random-decode-len is set.
                new_tokens = (
                    args.len_decode
                    if not RANDOM_DECODE_LEN
                    else random.randint(1, args.len_decode)
                )
                fork_s += sgl.gen(
                    max_tokens=new_tokens,
                    ignore_eos=True,
                )
            new_forks.extend(forks)

        # The children become the next frontier.
        old_forks = new_forks
        new_forks = []

    for old_fork in old_forks:
        # Append the END marker that main() later verifies.
        old_fork += "END"
        leaf_states.append(old_fork)
|
||||
|
||||
|
||||
@sgl.function
def robust_test(s, args):
    # Drive one fork-tree traversal (BFS or DFS) and return all leaf states.
    leaf_states = []
    if args.mode == "bfs":
        robust_test_bfs(s, args, leaf_states)
    else:
        robust_test_dfs(s, args.depth, args, leaf_states)
    return leaf_states
|
||||
|
||||
|
||||
def main(args):
    """Run ``num_req`` fork-tree programs in a batch and dump every leaf transcript to a file."""
    backend = select_sglang_backend(args)

    batch_args = [{"args": args} for _ in range(args.num_req)]
    states = robust_test.run_batch(
        batch_args, temperature=0, backend=backend, num_threads=args.parallel
    )

    # Every leaf must end with the "END" marker; strip it before writing.
    with open(f"tmp_robust_{args.mode}.txt", "w") as out:
        for state in states:
            for leaf in state.ret_value:
                text = leaf.text()
                assert text[-3:] == "END"
                out.write(text[:-3] + "\n")
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # fmt: off
    parser = argparse.ArgumentParser()
    parser.add_argument("--num-req", type=int, default=2)
    parser.add_argument("--depth", type=int, default=3)
    parser.add_argument("--num-fork", type=int, default=2)
    parser.add_argument("--len-prefill", type=int, default=128)
    parser.add_argument("--len-decode", type=int, default=128)
    parser.add_argument("--random-prefill-len", action="store_true")
    parser.add_argument("--random-decode-len", action="store_true")
    parser.add_argument("--mode", type=str, default="bfs", choices=["dfs", "bfs"])
    parser.add_argument("--tokenizer", type=str, default = "meta-llama/Llama-2-7b-chat-hf")
    parser.add_argument("--trust-remote-code", action="store_true")
    parser.add_argument("--seed", type=int, default=42)
    args = add_common_sglang_args_and_parse(parser)
    # fmt: on

    # These are top-level assignments, so they rebind the module globals
    # read by gen_prompt() and the fork-tree helpers above.
    RANDOM_PREFILL_LEN = args.random_prefill_len
    RANDOM_DECODE_LEN = args.random_decode_len
    TOKENIZER = get_tokenizer(args.tokenizer, trust_remote_code=args.trust_remote_code)

    # Seed so the random prompts (and hence the run) are reproducible.
    random.seed(args.seed)

    main(args)
|
||||
Reference in New Issue
Block a user