Improve the structure of CI (#911)

This commit is contained in:
Ying Sheng
2024-08-03 23:09:21 -07:00
committed by GitHub
parent 539856455d
commit 995af5a54b
29 changed files with 451 additions and 237 deletions

View File

@@ -0,0 +1,9 @@
# Smoke-test the local sglang server: request a short greedy
# (temperature 0) completion from the /generate endpoint.
curl http://localhost:30000/generate \
  -H "Content-Type: application/json" \
  -d '{
    "text": "Once upon a time,",
    "sampling_params": {
      "max_new_tokens": 16,
      "temperature": 0
    }
  }'

View File

@@ -0,0 +1,215 @@
import pytest
import torch
from flashinfer import (
BatchDecodeWithPagedKVCacheWrapper,
BatchPrefillWithPagedKVCacheWrapper,
)
from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
from sglang.srt.layers.extend_attention import extend_attention_fwd, redundant_attention
from sglang.srt.layers.token_attention import token_attention_fwd
# Module-level wrapper handles; populated by init_flashinfer() before use.
flashinfer_prefill_wrapper = None
flashinfer_decode_wrapper = None
@pytest.mark.parametrize("batch_size", [12, 37, 67])
@pytest.mark.parametrize("kv_len", [54, 97])
@pytest.mark.parametrize("qo_len", [37, 17])
@pytest.mark.parametrize("num_kv_heads", [4])
@pytest.mark.parametrize("num_qo_heads", [32, 4])
@pytest.mark.parametrize("head_dim", [128])
def test_batch_prefill_with_paged_kv_cache(
    batch_size,
    kv_len,
    qo_len,
    num_kv_heads,
    num_qo_heads,
    head_dim,
):
    """Cross-check three prefill attention implementations on identical inputs.

    The triton extend kernel (extend_attention_fwd) is compared against the
    reference redundant_attention, and then against flashinfer's paged prefill
    wrapper.  Page size is 1, so kv_indices simply enumerates every token.
    """
    init_flashinfer(num_qo_heads, num_kv_heads)

    q = torch.randn(batch_size * qo_len, num_qo_heads, head_dim).to(0).half()
    qo_indptr = torch.arange(0, batch_size + 1).to(0).int() * qo_len

    total_tokens = kv_len * batch_size
    # kv_data interleaves K (index 0) and V (index 1) along dim 1.
    kv_data = torch.randn(total_tokens, 2, num_kv_heads, head_dim).to(0).half()
    kv_indptr = torch.arange(0, batch_size + 1).to(0).int() * kv_len
    kv_indices = torch.arange(0, total_tokens).to(0).int()
    # Page size 1 => the last page of every request holds exactly one token.
    kv_last_page_len = torch.full((batch_size,), 1, dtype=torch.int32).to(0)

    # init args for triton kernel
    # The last qo_len tokens of each sequence are the "extend" part.
    k_extend = (
        kv_data.view(batch_size, kv_len, 2, -1)[:, -qo_len:, 0]
        .contiguous()
        .view(-1, num_kv_heads, head_dim)
    )
    v_extend = (
        kv_data.view(batch_size, kv_len, 2, -1)[:, -qo_len:, 1]
        .contiguous()
        .view(-1, num_kv_heads, head_dim)
    )
    o_triton = torch.empty_like(q)
    k_buffer = kv_data[:, 0].view(-1, num_kv_heads, head_dim).contiguous()
    v_buffer = kv_data[:, 1].view(-1, num_kv_heads, head_dim).contiguous()
    # Identity token mapping: request i owns tokens [i*kv_len, (i+1)*kv_len).
    req_to_token = torch.arange(0, total_tokens).to(0).int().view(batch_size, kv_len)
    b_req_idx = torch.arange(0, batch_size).to(0).int()
    b_seq_len = torch.full((batch_size,), kv_len, dtype=torch.int32).to(0)
    b_start_loc_extend = torch.arange(0, batch_size).to(0).int() * qo_len
    b_seq_len_extend = torch.full((batch_size,), qo_len, dtype=torch.int32).to(0)
    max_len_in_batch = kv_len
    max_len_extend = qo_len
    extend_attention_fwd(
        q,
        k_extend,
        v_extend,
        o_triton,
        k_buffer,
        v_buffer,
        req_to_token,
        b_req_idx,
        None,  # b_start_loc = None
        b_seq_len,
        None,  # b_seq_len_prefix = None
        b_start_loc_extend,
        b_seq_len_extend,
        max_len_in_batch,
        max_len_extend,
    )

    # Reference implementation over the same buffers.
    o_redundant = torch.empty_like(q)
    b_start_loc = torch.zeros((batch_size,), dtype=torch.int32).to(0)
    b_start_loc[1:] = torch.cumsum(b_seq_len[:-1], dim=0)
    b_seq_len_prefix = b_seq_len - b_seq_len_extend
    redundant_attention(
        q,
        k_extend,
        v_extend,
        o_redundant,
        k_buffer,
        v_buffer,
        req_to_token,
        b_req_idx,
        b_start_loc,
        b_seq_len,
        b_seq_len_prefix,
        max_len_in_batch,
    )
    print("Mean: ", torch.mean(torch.abs(o_redundant - o_triton)))
    print("Max: ", torch.max(torch.abs(o_redundant - o_triton)))
    assert torch.allclose(o_redundant, o_triton, rtol=1e-2, atol=1e-3)

    # flashinfer prefill over the same paged cache.
    flashinfer_prefill_wrapper.end_forward()
    flashinfer_prefill_wrapper.begin_forward(
        qo_indptr,
        kv_indptr,
        kv_indices,
        kv_last_page_len,
        num_qo_heads,
        num_kv_heads,
        head_dim,
        1,  # page size
    )
    o = flashinfer_prefill_wrapper.forward(
        q.contiguous().view(-1, num_qo_heads, head_dim), kv_data
    )
    print("Mean: ", torch.mean(torch.abs(o - o_triton)))
    print("Max: ", torch.max(torch.abs(o - o_triton)))
    assert torch.allclose(o, o_triton, rtol=1e-2, atol=1e-3)
@pytest.mark.parametrize("batch_size", [12, 17, 37])
@pytest.mark.parametrize("kv_len", [54, 127, 537])
@pytest.mark.parametrize("num_kv_heads", [32])
@pytest.mark.parametrize("num_qo_heads", [32])
@pytest.mark.parametrize("head_dim", [128])
def test_batch_decode_with_paged_kv_cache(
    batch_size,
    kv_len,
    num_kv_heads,
    num_qo_heads,
    head_dim,
):
    """Cross-check the triton decode kernel against flashinfer's paged decode.

    One query token per sequence (a single decode step); page size is 1.
    """
    # note(lsyin): when pytest, the number of heads cannot change, because triton kernel has a cache
    # to test different shape of decode, change the parameters in the __main__, and run decode only once
    init_flashinfer(num_qo_heads, num_kv_heads)

    q = torch.randn(batch_size, num_qo_heads, head_dim).to(0).half()
    total_tokens = kv_len * batch_size
    # kv_data interleaves K (index 0) and V (index 1) along dim 1.
    kv_data = torch.randn(total_tokens, 2, num_kv_heads, head_dim).to(0).half()
    kv_indptr = torch.arange(0, batch_size + 1).to(0).int() * kv_len
    kv_indices = torch.arange(0, total_tokens).to(0).int()
    # Page size 1 => the last page of every request holds exactly one token.
    kv_last_page_len = torch.full((batch_size,), 1, dtype=torch.int32).to(0)

    # init args for triton kernel
    k_buffer = kv_data[:, 0].view(-1, num_kv_heads, head_dim).contiguous()
    v_buffer = kv_data[:, 1].view(-1, num_kv_heads, head_dim).contiguous()
    o_triton = torch.empty_like(q)
    # Identity token mapping: request i owns tokens [i*kv_len, (i+1)*kv_len).
    req_to_token = (
        torch.arange(0, kv_len * batch_size).to(0).int().view(batch_size, kv_len)
    )
    b_req_idx = torch.arange(0, batch_size).to(0).int()
    b_start_loc = torch.arange(0, batch_size).to(0).int() * kv_len
    b_seq_len = torch.full((batch_size,), kv_len, dtype=torch.int32).to(0)
    max_len_in_batch = kv_len
    other_kv_index = 0
    token_attention_fwd(
        q,
        k_buffer,
        v_buffer,
        o_triton,
        req_to_token,
        b_req_idx,
        b_start_loc,
        b_seq_len,
        max_len_in_batch,
        other_kv_index,
        total_tokens,
    )

    # flashinfer decode over the same paged cache.
    flashinfer_decode_wrapper.end_forward()
    flashinfer_decode_wrapper.begin_forward(
        kv_indptr,
        kv_indices,
        kv_last_page_len,
        num_qo_heads,
        num_kv_heads,
        head_dim,
        1,  # page size
        pos_encoding_mode="NONE",
        data_type="float16",
    )
    o = flashinfer_decode_wrapper.forward(
        q.contiguous().view(-1, num_qo_heads, head_dim), kv_data
    )
    print("Mean: ", torch.mean(torch.abs(o - o_triton)))
    print("Max: ", torch.max(torch.abs(o - o_triton)))
    # Looser atol than the prefill test — presumably fp16 decode accumulates
    # slightly more error; confirm against kernel authors if tightened.
    assert torch.allclose(o, o_triton, rtol=1e-2, atol=2e-3)
def init_flashinfer(num_attention_heads, num_kv_heads):
    """(Re)create the global flashinfer prefill/decode wrappers.

    Both wrappers share one 128 MB workspace buffer on the current CUDA device.
    """
    global flashinfer_prefill_wrapper, flashinfer_decode_wrapper

    # Fall back to the tensor-core decode path when no specialized grouped
    # decode kernel was compiled for this head ratio.
    use_tensor_cores = not _grouped_size_compiled_for_decode_kernels(
        num_attention_heads, num_kv_heads
    )

    workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8, device="cuda")
    flashinfer_prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper(
        workspace_buffer, "NHD"
    )
    flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
        workspace_buffer, "NHD", use_tensor_cores=use_tensor_cores
    )
if __name__ == "__main__":
    # Direct invocation covers shapes outside the pytest parametrization;
    # per the note in test_batch_decode_with_paged_kv_cache, decode is run
    # only once here because the triton kernel caches its head configuration.
    test_batch_prefill_with_paged_kv_cache(12, 54, 37, 8, 8, 128)
    test_batch_prefill_with_paged_kv_cache(37, 1111, 456, 32, 32, 128)
    test_batch_decode_with_paged_kv_cache(12, 54, 4, 32, 128)

View File

@@ -0,0 +1,69 @@
"""
Usage:
python3 -m sglang.launch_server --model-path /model/llama-classification
python3 test_httpserver_classify.py
"""
import argparse
import numpy as np
import requests
def get_logits(url, prompt):
    """Return the normalized prompt logprob for one prompt via /generate."""
    payload = {
        "text": prompt,
        "sampling_params": {
            "max_new_tokens": 0,
        },
        "return_logprob": True,
    }
    resp = requests.post(url + "/generate", json=payload)
    return resp.json()["meta_info"]["normalized_prompt_logprob"]
def get_logits_batch(url, prompts):
    """Return normalized prompt logprobs for a batch of prompts.

    Args:
        url: server base URL (e.g. "http://127.0.0.1:30000").
        prompts: list of prompt strings, sent as one batched /generate call.

    Returns:
        np.ndarray with one normalized prompt logprob per input prompt.
    """
    response = requests.post(
        url + "/generate",
        json={
            "text": prompts,
            "sampling_params": {
                "max_new_tokens": 0,
            },
            "return_logprob": True,
        },
    )
    ret = response.json()
    # List comprehension instead of list(generator) (ruff C400).
    logits = np.array(
        [
            ret[i]["meta_info"]["normalized_prompt_logprob"]
            for i in range(len(prompts))
        ]
    )
    return logits
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="http://127.0.0.1")
    parser.add_argument("--port", type=int, default=30000)
    args = parser.parse_args()
    url = f"{args.host}:{args.port}"

    # A single request
    logits = get_logits(url, "This is a test prompt.<|eot_id|>")
    print(f"{logits=}")

    # A batch of requests
    batch_prompts = [
        "This is a test prompt.<|eot_id|>",
        "This is another test prompt.<|eot_id|>",
        "This is a long long long long test prompt.<|eot_id|>",
    ]
    logits = get_logits_batch(url, batch_prompts)
    print(f"{logits=}")

View File

@@ -0,0 +1,56 @@
"""
python3 -m sglang.launch_server --model-path TinyLlama/TinyLlama-1.1B-Chat-v0.4 --port 30000
Output:
The capital of France is Paris.\nThe capital of the United States is Washington, D.C.
The capital of the United Kindom is London.\nThe capital of the United Kingdom is London.\nThe capital of
"""
import argparse
import asyncio
import json
import time
import aiohttp
import requests
async def send_request(url, data, delay=0):
    """POST ``data`` as JSON to ``url`` after ``delay`` seconds; return the JSON reply."""
    await asyncio.sleep(delay)
    async with aiohttp.ClientSession() as session:
        async with session.post(url, json=data) as resp:
            return await resp.json()
async def main(args):
    """Issue two generate requests concurrently (the first is delayed 1s)."""
    base = f"{args.host}:{args.port}"
    payload_france = {
        "text": "The capital of France is",
        "sampling_params": {"temperature": 0, "max_new_tokens": 128},
    }
    payload_uk = {
        "text": "The capital of the United Kindom is",
        "sampling_params": {"temperature": 0, "max_new_tokens": 128},
    }
    tasks = [
        send_request(base + "/generate", payload_france, delay=1),
        send_request(base + "/generate", payload_uk),
    ]
    rets = await asyncio.gather(*tasks)
    print(rets)
if __name__ == "__main__":
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument("--host", type=str, default="http://127.0.0.1")
    arg_parser.add_argument("--port", type=int, default=30000)
    asyncio.run(main(arg_parser.parse_args()))

View File

@@ -0,0 +1,55 @@
"""
Usage:
python3 -m sglang.launch_server --model-path TinyLlama/TinyLlama-1.1B-Chat-v0.4 --port 30000
python3 test_httpserver_decode.py
Output:
The capital of France is Paris.\nThe capital of the United States is Washington, D.C.\nThe capital of Canada is Ottawa.\nThe capital of Japan is Tokyo
"""
import argparse
import json
import requests
def test_decode(url, return_logprob=False, top_logprobs_num=0, return_text=False, n=1):
    """Send one non-streaming /generate request and dump the JSON reply."""
    # Greedy when a single sample is requested; mild temperature for n > 1.
    payload = {
        "text": "The capital of France is",
        "sampling_params": {
            "temperature": 0 if n == 1 else 0.5,
            "max_new_tokens": 32,
            "n": n,
        },
        "stream": False,
        "return_logprob": return_logprob,
        "top_logprobs_num": top_logprobs_num,
        "return_text_in_logprobs": return_text,
        "logprob_start_len": 0,
    }
    response = requests.post(url + "/generate", json=payload)
    print(json.dumps(response.json()))
    print("=" * 100)
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="http://127.0.0.1")
    parser.add_argument("--port", type=int, default=30000)
    args = parser.parse_args()
    url = f"{args.host}:{args.port}"

    test_decode(url)
    test_decode(url, n=3)
    # Exercise every logprob option combination (same order as before).
    for top_logprobs_num in (0, 3):
        for return_text in (True, False):
            test_decode(
                url,
                return_logprob=True,
                top_logprobs_num=top_logprobs_num,
                return_text=return_text,
            )

View File

@@ -0,0 +1,69 @@
"""
Usage:
python3 -m sglang.launch_server --model-path TinyLlama/TinyLlama-1.1B-Chat-v0.4 --port 30000
python3 test_httpserver_decode_stream.py
Output:
The capital of France is Paris.\nThe capital of the United States is Washington, D.C.\nThe capital of Canada is Ottawa.\nThe capital of Japan is Tokyo
"""
import argparse
import json
import requests
def test_decode_stream(url, return_logprob, top_logprobs_num):
    """Stream a /generate request, printing tokens (or logprobs) as they arrive."""
    response = requests.post(
        url + "/generate",
        json={
            "text": "The capital of France is",
            "sampling_params": {
                "temperature": 0,
                "max_new_tokens": 128,
            },
            "stream": True,
            "return_logprob": return_logprob,
            "top_logprobs_num": top_logprobs_num,
            "return_text_in_logprobs": True,
            "logprob_start_len": 0,
        },
        stream=True,
    )

    prev = 0  # how much of the cumulative output has already been printed
    for chunk in response.iter_lines(decode_unicode=False):
        chunk = chunk.decode("utf-8")
        # Server-sent-events style framing: payload lines start with "data:".
        if chunk and chunk.startswith("data:"):
            if chunk == "data: [DONE]":
                break
            data = json.loads(chunk[5:].strip("\n"))
            if return_logprob:
                assert data["meta_info"]["input_token_logprobs"] is not None
                assert data["meta_info"]["output_token_logprobs"] is not None
                assert data["meta_info"]["normalized_prompt_logprob"] is not None
                # Each entry is (logprob, token_id, token_text); print only the
                # tokens added since the previous chunk.
                for logprob, token_id, token_text in data["meta_info"][
                    "output_token_logprobs"
                ][prev:]:
                    print(f"{token_text:12s}\t{logprob}\t{token_id}", flush=True)
                prev = len(data["meta_info"]["output_token_logprobs"])
            else:
                # "text" is cumulative; print only the newly generated suffix.
                output = data["text"].strip()
                print(output[prev:], end="", flush=True)
                prev = len(output)

    print("=" * 100)
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="http://127.0.0.1")
    parser.add_argument("--port", type=int, default=30000)
    args = parser.parse_args()
    base_url = f"{args.host}:{args.port}"

    # Same three configurations as before: plain, logprobs, top-3 logprobs.
    for return_logprob, top_logprobs_num in ((False, 0), (True, 0), (True, 3)):
        test_decode_stream(base_url, return_logprob, top_logprobs_num)

View File

@@ -0,0 +1,88 @@
"""
Usage:
python3 -m sglang.launch_server --model-path liuhaotian/llava-v1.5-7b --tokenizer-path llava-hf/llava-1.5-7b-hf --port 30000
python3 test_httpserver_llava.py
Output:
The image features a man standing on the back of a yellow taxi cab, holding
"""
import argparse
import asyncio
import json
import aiohttp
import requests
async def send_request(url, data, delay=0):
    """POST ``data`` as JSON after an optional delay and return the parsed reply."""
    await asyncio.sleep(delay)
    async with aiohttp.ClientSession() as session:
        async with session.post(url, json=data) as resp:
            return await resp.json()
async def test_concurrent(args):
    """Launch eight identical image requests at once and print each completion."""
    base = f"{args.host}:{args.port}"
    payload = {
        "text": "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. USER: <image>\nDescribe this picture ASSISTANT:",
        "image_data": "example_image.png",
        "sampling_params": {
            "temperature": 0,
            "max_new_tokens": 16,
        },
    }
    # send_request only serializes the payload, so sharing one dict is safe.
    tasks = [send_request(base + "/generate", payload) for _ in range(8)]
    for ret in await asyncio.gather(*tasks):
        print(ret["text"])
def test_streaming(args):
    """Stream one image-grounded completion and print deltas as they arrive."""
    url = f"{args.host}:{args.port}"
    response = requests.post(
        url + "/generate",
        json={
            "text": "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. USER: <image>\nDescribe this picture ASSISTANT:",
            "image_data": "example_image.png",
            "sampling_params": {
                "temperature": 0,
                "max_new_tokens": 128,
            },
            "stream": True,
        },
        stream=True,
    )

    prev = 0  # length of the cumulative text already printed
    for chunk in response.iter_lines(decode_unicode=False):
        chunk = chunk.decode("utf-8")
        # Server-sent-events style framing: payload lines start with "data:".
        if chunk and chunk.startswith("data:"):
            if chunk == "data: [DONE]":
                break
            data = json.loads(chunk[5:].strip("\n"))
            # "text" is cumulative; print only the newly generated suffix.
            output = data["text"].strip()
            print(output[prev:], end="", flush=True)
            prev = len(output)
    print("")
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="http://127.0.0.1")
    parser.add_argument("--port", type=int, default=30000)
    cli_args = parser.parse_args()

    asyncio.run(test_concurrent(cli_args))
    test_streaming(cli_args)

View File

@@ -0,0 +1,42 @@
"""
python3 -m sglang.launch_server --model-path TinyLlama/TinyLlama-1.1B-Chat-v0.4 --port 30000
Output:
The capital of France is Paris.\nThe capital of the United States is Washington, D.C.\nThe capital of Canada is Ottawa.\nThe capital of Japan is Tokyo
"""
import argparse
import requests
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="http://127.0.0.1")
    parser.add_argument("--port", type=int, default=30000)
    args = parser.parse_args()
    url = f"{args.host}:{args.port}"

    # Two sequential requests; the second prompt extends the first's expected
    # answer, so both go through the same greedy /generate settings.
    for prompt in (
        "The capital of France is",
        "The capital of France is Paris.\nThe capital of the United States is",
    ):
        response = requests.post(
            url + "/generate",
            json={
                "text": prompt,
                "sampling_params": {
                    "temperature": 0,
                    "max_new_tokens": 32,
                },
            },
        )
        print(response.json())

View File

@@ -0,0 +1,138 @@
import argparse
from enum import Enum
from pydantic import BaseModel, constr
import sglang as sgl
from sglang.srt.constrained import build_regex_from_object
from sglang.test.test_utils import (
add_common_sglang_args_and_parse,
select_sglang_backend,
)
# IPv4 dotted-quad regex: four octets, each restricted to 0-255.
IP_REGEX = r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)"

# Constrained-generation template used by regex_gen: fixed narrative text
# with two IP-address slots and one www domain slot.
ip_jump_forward = (
    r"The google's DNS sever address is "
    + IP_REGEX
    + r" and "
    + IP_REGEX
    + r". "
    + r"The google's website domain name is "
    + r"www\.(\w)+\.(\w)+"
    + r"."
)
# fmt: off
@sgl.function
def regex_gen(s):
    """Generate an answer constrained to the ip_jump_forward regex."""
    s += "Q: What is the IP address of the Google DNS servers?\n"
    s += "A: " + sgl.gen(
        "answer",
        max_tokens=128,
        temperature=0,
        regex=ip_jump_forward,
    )
# fmt: on
json_jump_forward = (
r"""The information about Hogwarts is in the following JSON format\.\n"""
+ r"""\n\{\n"""
+ r""" "name": "[\w\d\s]*",\n"""
+ r""" "country": "[\w\d\s]*",\n"""
+ r""" "latitude": [-+]?[0-9]*\.?[0-9]+,\n"""
+ r""" "population": [-+]?[0-9]+,\n"""
+ r""" "top 3 landmarks": \["[\w\d\s]*", "[\w\d\s]*", "[\w\d\s]*"\],\n"""
+ r"""\}\n"""
)
# fmt: off
@sgl.function
def json_gen(s):
    """Generate text constrained to the json_jump_forward regex."""
    s += sgl.gen(
        "json",
        max_tokens=128,
        temperature=0,
        regex=json_jump_forward,
    )
# fmt: on
class Weapon(str, Enum):
    """Closed set of weapon choices for the generated character schema."""

    sword = "sword"
    axe = "axe"
    mace = "mace"
    spear = "spear"
    bow = "bow"
    crossbow = "crossbow"
class Armor(str, Enum):
    """Closed set of armor choices for the generated character schema."""

    leather = "leather"
    chainmail = "chainmail"
    plate = "plate"
class Character(BaseModel):
    """Pydantic schema from which build_regex_from_object derives a regex."""

    # Name is capped at 10 characters via pydantic's constr.
    name: constr(max_length=10)
    age: int
    armor: Armor
    weapon: Weapon
    strength: int
@sgl.function
def character_gen(s):
    """Generate a character description constrained to the Character schema."""
    s += "Give me a character description who is a wizard.\n"
    s += sgl.gen(
        "character",
        max_tokens=128,
        temperature=0,
        regex=build_regex_from_object(Character),
    )
def main(args):
    """Run the three constrained-generation checks against the chosen backend."""
    # Select backend
    sgl.set_default_backend(select_sglang_backend(args))

    for title, program in (
        ("IP TEST", regex_gen),
        ("JSON TEST", json_gen),
        ("CHARACTER TEST", character_gen),
    ):
        state = program.run(temperature=0)
        print("=" * 20, title, "=" * 20)
        print(state.text())
if __name__ == "__main__":
    main(add_common_sglang_args_and_parse(argparse.ArgumentParser()))
# ==================== IP TEST ====================
# Q: What is the IP address of the Google DNS servers?
# A: The google's DNS sever address is 8.8.8.8 and 8.8.4.4. The google's website domain name is www.google.com.
# ==================== JSON TEST ====================
# The information about Hogwarts is in the following JSON format.
# {
# "name": "Hogwarts School of Witchcraft and Wizardry",
# "country": "Scotland",
# "latitude": 55.566667,
# "population": 1000,
# "top 3 landmarks": ["Hogwarts Castle", "The Great Hall", "The Forbidden Forest"],
# }
# ==================== CHARACTER TEST ====================
# Give me a character description who is a wizard.
# { "name" : "Merlin", "age" : 500, "armor" : "chainmail" , "weapon" : "sword" , "strength" : 10 }

View File

@@ -0,0 +1,209 @@
"""
First run the following command to launch the server.
Note that TinyLlama adopts different chat templates in different versions.
For v0.4, the chat template is chatml.
python3 -m sglang.launch_server --model-path TinyLlama/TinyLlama-1.1B-Chat-v0.4 \
--port 30000 --chat-template chatml
Output example:
The capital of France is Paris.
The capital of the United States is Washington, D.C.
The capital of Canada is Ottawa.
The capital of Japan is Tokyo
"""
import argparse
import json
import openai
def test_completion(args, echo, logprobs):
    """Exercise the /v1/completions endpoint.

    Args:
        args: parsed CLI args providing ``base_url``.
        echo: ask the server to echo the prompt back in the completion.
        logprobs: falsy to disable, or an int for top-k logprobs.
    """
    client = openai.Client(api_key="EMPTY", base_url=args.base_url)
    response = client.completions.create(
        model="default",
        prompt="The capital of France is",
        temperature=0,
        max_tokens=32,
        echo=echo,
        logprobs=logprobs,
    )
    text = response.choices[0].text
    print(text)
    if echo:
        assert text.startswith("The capital of France is")
    if logprobs:
        print(response.choices[0].logprobs.top_logprobs)
        assert response.choices[0].logprobs
        # The asserted contract: when echoing, the first (prompt) token has no
        # logprob.  Use identity comparison with None (ruff E711), not ==/!=.
        if echo:
            assert response.choices[0].logprobs.token_logprobs[0] is None
        else:
            assert response.choices[0].logprobs.token_logprobs[0] is not None
    assert response.id
    assert response.created
    assert response.usage.prompt_tokens > 0
    assert response.usage.completion_tokens > 0
    assert response.usage.total_tokens > 0
    print("=" * 100)
def test_completion_stream(args, echo, logprobs):
    """Stream a completion, printing text (and logprobs when requested).

    Args:
        args: parsed CLI args providing ``base_url``.
        echo: ask the server to echo the prompt at the start of the stream.
        logprobs: falsy to disable, or an int for top-k logprobs.
    """
    client = openai.Client(api_key="EMPTY", base_url=args.base_url)
    response = client.completions.create(
        model="default",
        prompt="The capital of France is",
        temperature=0,
        max_tokens=32,
        stream=True,
        echo=echo,
        logprobs=logprobs,
    )
    first = True
    for r in response:
        if first:
            # With echo, the prompt text must lead the stream.
            if echo:
                assert r.choices[0].text.startswith("The capital of France is")
            first = False
        if logprobs:
            print(
                f"{r.choices[0].text:12s}\t" f"{r.choices[0].logprobs.token_logprobs}",
                flush=True,
            )
            print(r.choices[0].logprobs.top_logprobs)
        else:
            print(r.choices[0].text, end="", flush=True)
        # NOTE(review): usage is asserted on every streamed chunk — assumes
        # this server populates usage per chunk; confirm against the backend.
        assert r.id
        assert r.usage.prompt_tokens > 0
        assert r.usage.completion_tokens > 0
        assert r.usage.total_tokens > 0
    print("=" * 100)
def test_chat_completion(args):
    """Basic non-streaming chat completion round trip."""
    client = openai.Client(api_key="EMPTY", base_url=args.base_url)
    messages = [
        {"role": "system", "content": "You are a helpful AI assistant"},
        {"role": "user", "content": "What is the capital of France?"},
    ]
    response = client.chat.completions.create(
        model="default",
        messages=messages,
        temperature=0,
        max_tokens=32,
    )
    print(response.choices[0].message.content)

    # Sanity-check the response envelope.
    assert response.id
    assert response.created
    assert response.usage.prompt_tokens > 0
    assert response.usage.completion_tokens > 0
    assert response.usage.total_tokens > 0
    print("=" * 100)
def test_chat_completion_image(args):
    """Non-streaming chat completion with an image URL in the user message."""
    client = openai.Client(api_key="EMPTY", base_url=args.base_url)
    user_content = [
        {"type": "text", "text": "Describe this image"},
        {
            "type": "image_url",
            "image_url": {
                "url": "https://raw.githubusercontent.com/sgl-project/sglang/main/assets/mixtral_8x7b.jpg"
            },
        },
    ]
    response = client.chat.completions.create(
        model="default",
        messages=[
            {"role": "system", "content": "You are a helpful AI assistant"},
            {"role": "user", "content": user_content},
        ],
        temperature=0,
        max_tokens=32,
    )
    print(response.choices[0].message.content)

    # Sanity-check the response envelope.
    assert response.id
    assert response.created
    assert response.usage.prompt_tokens > 0
    assert response.usage.completion_tokens > 0
    assert response.usage.total_tokens > 0
    print("=" * 100)
def test_chat_completion_stream(args):
    """Streaming chat completion: first chunk carries the role, later ones text."""
    client = openai.Client(api_key="EMPTY", base_url=args.base_url)
    stream = client.chat.completions.create(
        model="default",
        messages=[
            {"role": "system", "content": "You are a helpful AI assistant"},
            {"role": "user", "content": "List 3 countries and their capitals."},
        ],
        temperature=0,
        max_tokens=64,
        stream=True,
    )
    is_first = True
    for chunk in stream:
        delta = chunk.choices[0].delta
        if is_first:
            # The opening chunk announces the assistant role and has no text.
            assert delta.role == "assistant"
            is_first = False
            continue
        if delta.content:
            print(delta.content, end="", flush=True)
    print("=" * 100)
def test_regex(args):
    """Constrain chat output to a small JSON object via the `regex` extra arg."""
    client = openai.Client(api_key="EMPTY", base_url=args.base_url)
    # Regex for: {"name": "<word>", "population": <digits>}
    regex = (
        r"""\{\n"""
        + r"""  "name": "[\w]+",\n"""
        + r"""  "population": [\d]+\n"""
        + r"""\}"""
    )
    response = client.chat.completions.create(
        model="default",
        messages=[
            {"role": "system", "content": "You are a helpful AI assistant"},
            {"role": "user", "content": "Introduce the capital of France."},
        ],
        temperature=0,
        max_tokens=128,
        extra_body={"regex": regex},
    )
    text = response.choices[0].message.content
    # The constrained output must parse as valid JSON.
    print(json.loads(text))
    print("=" * 100)
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--base-url", type=str, default="http://127.0.0.1:30000/v1")
    parser.add_argument(
        "--test-image", action="store_true", help="Enables testing image inputs"
    )
    args = parser.parse_args()

    # Sweep (echo x logprobs) over the blocking and streaming endpoints,
    # preserving the original call order.
    for lp in (False, True, 3):
        for ec in (False, True):
            test_completion(args, echo=ec, logprobs=lp)
    for lp in (False, True, 3):
        for ec in (False, True):
            test_completion_stream(args, echo=ec, logprobs=lp)

    test_chat_completion(args)
    test_chat_completion_stream(args)
    test_regex(args)
    if args.test_image:
        test_chat_completion_image(args)

View File

@@ -0,0 +1,132 @@
import argparse
import random
import string
from vllm.transformers_utils.tokenizer import get_tokenizer
import sglang as sgl
from sglang.test.test_utils import (
add_common_sglang_args_and_parse,
select_sglang_backend,
)
# Filled in from the CLI flags in __main__ before main() runs.
TOKENIZER = None
RANDOM_PREFILL_LEN = None
RANDOM_DECODE_LEN = None
def gen_prompt(token_num):
    """Return a random alphanumeric string tokenizing to >= token_num tokens."""
    if RANDOM_PREFILL_LEN:
        token_num = random.randint(1, token_num)

    alphabet = string.ascii_letters + string.digits
    prompt = "".join(random.choices(alphabet, k=token_num))
    # Pad one character at a time until the tokenizer sees enough tokens.
    while len(TOKENIZER(prompt).input_ids) < token_num:
        prompt += random.choice(alphabet)
    return prompt
def robust_test_dfs(s, d, args, leaf_states):
    """Depth-first fork tree: each node appends a prompt, forks, and recurses.

    Args:
        s: current sgl program state.
        d: remaining depth; at 0 the state is terminated with "END".
        args: CLI namespace (len_prefill, len_decode, num_fork, ...).
        leaf_states: output list collecting every terminated leaf state.
    """
    if d == 0:
        s += "END"
        leaf_states.append(s)
        return

    s += gen_prompt(args.len_prefill)
    forks = s.fork(args.num_fork)
    # First pass: queue prompt + generation on every fork ...
    for fork_s in forks:
        fork_s += gen_prompt(args.len_prefill)
        new_tokens = (
            args.len_decode
            if not RANDOM_DECODE_LEN
            else random.randint(1, args.len_decode)
        )
        fork_s += sgl.gen(
            max_tokens=new_tokens,
            ignore_eos=True,
        )
    # ... second pass: recurse into each fork.
    for fork_s in forks:
        robust_test_dfs(fork_s, d - 1, args, leaf_states)
def robust_test_bfs(s, args, leaf_states):
    """Breadth-first fork tree: expand one whole level of forks per iteration.

    Mirrors robust_test_dfs level by level; every fork surviving to the final
    level is terminated with "END" and collected into leaf_states.
    """
    old_forks = [s]
    new_forks = []
    for _ in range(args.depth):
        # Expand every state in the current level.
        for old_fork in old_forks:
            old_fork += gen_prompt(args.len_prefill)
            forks = old_fork.fork(args.num_fork)
            for fork_s in forks:
                fork_s += gen_prompt(args.len_prefill)
                new_tokens = (
                    args.len_decode
                    if not RANDOM_DECODE_LEN
                    else random.randint(1, args.len_decode)
                )
                fork_s += sgl.gen(
                    max_tokens=new_tokens,
                    ignore_eos=True,
                )
            new_forks.extend(forks)
        old_forks = new_forks
        new_forks = []

    # Terminate and collect the final level.
    for old_fork in old_forks:
        old_fork += "END"
        leaf_states.append(old_fork)
@sgl.function
def robust_test(s, args):
    """sgl program entry point: run the BFS or DFS fork stress test."""
    leaf_states = []
    if args.mode == "bfs":
        robust_test_bfs(s, args, leaf_states)
    else:
        robust_test_dfs(s, args.depth, args, leaf_states)
    return leaf_states
def main(args):
    """Run the fork-tree program num_req times and dump all leaf outputs."""
    backend = select_sglang_backend(args)
    batch_arguments = [{"args": args} for _ in range(args.num_req)]
    states = robust_test.run_batch(
        batch_arguments, temperature=0, backend=backend, num_threads=args.parallel
    )

    with open(f"tmp_robust_{args.mode}.txt", "w") as fout:
        for state in states:
            for leaf_state in state.ret_value:
                text = leaf_state.text()
                # Every leaf must have been terminated with the END sentinel.
                assert text[-3:] == "END"
                fout.write(text[:-3] + "\n")
if __name__ == "__main__":
    # fmt: off
    parser = argparse.ArgumentParser()
    parser.add_argument("--num-req", type=int, default=2)
    parser.add_argument("--depth", type=int, default=3)
    parser.add_argument("--num-fork", type=int, default=2)
    parser.add_argument("--len-prefill", type=int, default=128)
    parser.add_argument("--len-decode", type=int, default=128)
    parser.add_argument("--random-prefill-len", action="store_true")
    parser.add_argument("--random-decode-len", action="store_true")
    parser.add_argument("--mode", type=str, default="bfs", choices=["dfs", "bfs"])
    parser.add_argument("--tokenizer", type=str, default = "meta-llama/Llama-2-7b-chat-hf")
    parser.add_argument("--trust-remote-code", action="store_true")
    parser.add_argument("--seed", type=int, default=42)
    args = add_common_sglang_args_and_parse(parser)
    # fmt: on

    # Plain assignment at module scope rebinds the globals read by gen_prompt
    # and the robust_test_* helpers (no `global` statement needed here).
    RANDOM_PREFILL_LEN = args.random_prefill_len
    RANDOM_DECODE_LEN = args.random_decode_len
    TOKENIZER = get_tokenizer(args.tokenizer, trust_remote_code=args.trust_remote_code)

    # Seed before any gen_prompt call so runs are reproducible.
    random.seed(args.seed)
    main(args)

View File

@@ -1,8 +0,0 @@
# Sort imports and reformat each top-level directory, in the same order.
for dir in python test benchmark; do
    isort "$dir"
    black "$dir"
done

View File

@@ -1,6 +0,0 @@
# Serve a local Llama-2-7b-chat snapshot with text-generation-inference on
# port 24000 (single shard, host networking, all GPUs, interactive TTY).
docker run --name tgi --rm -ti --gpus all --network host \
  -v /home/ubuntu/model_weights/Llama-2-7b-chat-hf:/Llama-2-7b-chat-hf \
  ghcr.io/huggingface/text-generation-inference:1.3.0 \
  --model-id /Llama-2-7b-chat-hf --num-shard 1 --trust-remote-code \
  --max-input-length 2048 --max-total-tokens 4096 \
  --port 24000