release initial code
Co-authored-by: Ying Sheng <sqy1415@gmail.com> Co-authored-by: Liangsheng Yin <hnyls2002@gmail.com> Co-authored-by: Zhiqiang Xie <xiezhq@stanford.edu> Co-authored-by: parasol-aser <3848358+parasol-aser@users.noreply.github.com> Co-authored-by: LiviaSun <33578456+ChuyueSun@users.noreply.github.com> Co-authored-by: Cody Yu <hao.yu.cody@gmail.com>
This commit is contained in:
1
test/killall_python.sh
Normal file
1
test/killall_python.sh
Normal file
@@ -0,0 +1 @@
|
||||
# Kill every running Python process.
# The [p]ython bracket trick keeps the pipeline's own grep from matching itself,
# so no `grep -v grep` stage is needed; `xargs -r` skips the kill entirely when
# nothing matches (the original `kill -9 $(...)` errored on an empty list).
ps aux | grep '[p]ython' | awk '{print $2}' | xargs -r kill -9
|
||||
60
test/lang/run_all.py
Normal file
60
test/lang/run_all.py
Normal file
@@ -0,0 +1,60 @@
|
||||
import argparse
|
||||
import glob
|
||||
import multiprocessing
|
||||
import os
|
||||
import time
|
||||
import unittest
|
||||
|
||||
from sglang.utils import run_with_timeout
|
||||
|
||||
|
||||
def run_unittest_files(files, args):
    """Run each unittest file in its own subprocess with a per-file timeout.

    Args:
        files: Iterable of unittest file paths to run.
        args: Parsed CLI namespace; only ``time_limit_per_file`` is read.

    Returns:
        True if every file exited with code 0; False on the first failing
        or timed-out file.
    """
    for filename in files:

        def func():
            print(filename)
            ret = unittest.main(module=None, argv=["", "-vb"] + [filename])

        p = multiprocessing.Process(target=func)

        def run_one_file():
            p.start()
            p.join()

        try:
            run_with_timeout(run_one_file, timeout=args.time_limit_per_file)
            if p.exitcode != 0:
                return False
        except TimeoutError:
            p.terminate()
            # Give the terminated process a moment to release resources.
            time.sleep(5)
            # Bug fix: the message previously printed the literal "(unknown)"
            # instead of the file that timed out.
            print(
                f"\nTimeout after {args.time_limit_per_file} seconds "
                f"when running {filename}"
            )
            return False

    return True
if __name__ == "__main__":
    # Discover every test_*.py under the current directory and run them all.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--time-limit-per-file",
        type=int,
        default=1000,
        help="The time limit for running one file in seconds.",
    )
    cli_args = parser.parse_args()

    test_files = glob.glob("**/test_*.py", recursive=True)

    start = time.time()
    ok = run_unittest_files(test_files, cli_args)
    elapsed = time.time() - start

    if ok:
        print(f"Success. Time elapsed: {elapsed:.2f}s")
    else:
        print(f"Fail. Time elapsed: {elapsed:.2f}s")

    exit(0 if ok else -1)
35
test/lang/test_anthropic_backend.py
Normal file
35
test/lang/test_anthropic_backend.py
Normal file
@@ -0,0 +1,35 @@
|
||||
import json
|
||||
import unittest
|
||||
|
||||
from sglang.test.test_programs import test_mt_bench, test_stream
|
||||
|
||||
from sglang import Anthropic, set_default_backend
|
||||
|
||||
|
||||
class TestAnthropicBackend(unittest.TestCase):
    """Run the shared sglang test programs against the Anthropic backend."""

    # Backends are created lazily on first setUp and shared across tests.
    backend = None
    chat_backend = None

    def setUp(self):
        klass = type(self)
        if klass.backend is None:
            klass.backend = Anthropic("claude-2")
            set_default_backend(klass.backend)

    def test_mt_bench(self):
        test_mt_bench()

    def test_stream(self):
        test_stream()
||||
if __name__ == "__main__":
    unittest.main(warnings="ignore")

# Manual debugging snippet:
# from sglang.global_config import global_config
# global_config.verbosity = 2
# t = TestAnthropicBackend()
# t.setUp()
# t.test_mt_bench()
||||
54
test/lang/test_bind_pin.py
Normal file
54
test/lang/test_bind_pin.py
Normal file
@@ -0,0 +1,54 @@
|
||||
import unittest
|
||||
|
||||
from sglang.backend.runtime_endpoint import RuntimeEndpoint
|
||||
|
||||
import sglang as sgl
|
||||
|
||||
|
||||
class TestBind(unittest.TestCase):
    """Exercise sgl function bind/pin against a local runtime endpoint."""

    backend = None

    def setUp(self):
        klass = type(self)
        if klass.backend is None:
            klass.backend = RuntimeEndpoint(base_url="http://localhost:30000")

    def test_bind(self):
        @sgl.function
        def few_shot_qa(s, prompt, question):
            s += prompt
            s += "Q: What is the capital of France?\n"
            s += "A: Paris\n"
            s += "Q: " + question + "\n"
            s += "A:" + sgl.gen("answer", stop="\n")

        bound = few_shot_qa.bind(
            prompt="The following are questions with answers.\n\n"
        )

        tracer = bound.trace()
        print(tracer.last_node.print_graph_dfs() + "\n")

    def test_pin(self):
        @sgl.function
        def few_shot_qa(s, prompt, question):
            s += prompt
            s += "Q: What is the capital of France?\n"
            s += "A: Paris\n"
            s += "Q: " + question + "\n"
            s += "A:" + sgl.gen("answer", stop="\n")

        bound = few_shot_qa.bind(
            prompt="Answer the following questions as if you were a 5-year-old kid.\n\n"
        )
        bound.pin(self.backend)
        bound.unpin(self.backend)
if __name__ == "__main__":
    unittest.main(warnings="ignore")

# Manual debugging snippet:
# t = TestBind()
# t.setUp()
# t.test_pin()
|
||||
91
test/lang/test_openai_backend.py
Normal file
91
test/lang/test_openai_backend.py
Normal file
@@ -0,0 +1,91 @@
|
||||
import unittest
|
||||
|
||||
from sglang.test.test_programs import (
|
||||
test_decode_int,
|
||||
test_decode_json,
|
||||
test_expert_answer,
|
||||
test_few_shot_qa,
|
||||
test_image_qa,
|
||||
test_mt_bench,
|
||||
test_parallel_decoding,
|
||||
test_parallel_encoding,
|
||||
test_react,
|
||||
test_select,
|
||||
test_stream,
|
||||
test_tool_use,
|
||||
)
|
||||
|
||||
from sglang import OpenAI, set_default_backend
|
||||
|
||||
|
||||
class TestOpenAIBackend(unittest.TestCase):
    """Run the shared sglang test programs against the OpenAI backends."""

    # One shared backend per model family, created lazily on first setUp.
    backend = None
    chat_backend = None
    chat_vision_backend = None

    def setUp(self):
        klass = type(self)
        if klass.backend is None:
            klass.backend = OpenAI("gpt-3.5-turbo-instruct")
            klass.chat_backend = OpenAI("gpt-3.5-turbo")
            klass.chat_vision_backend = OpenAI("gpt-4-vision-preview")

    def _run(self, backend, test_fn, **kwargs):
        # Select the backend, then delegate to the shared test program.
        set_default_backend(backend)
        test_fn(**kwargs)

    def test_few_shot_qa(self):
        self._run(self.backend, test_few_shot_qa)

    def test_mt_bench(self):
        self._run(self.chat_backend, test_mt_bench)

    def test_select(self):
        self._run(self.backend, test_select, check_answer=True)

    def test_decode_int(self):
        self._run(self.backend, test_decode_int)

    def test_decode_json(self):
        self._run(self.backend, test_decode_json)

    def test_expert_answer(self):
        self._run(self.backend, test_expert_answer)

    def test_tool_use(self):
        self._run(self.backend, test_tool_use)

    def test_react(self):
        self._run(self.backend, test_react)

    def test_parallel_decoding(self):
        self._run(self.backend, test_parallel_decoding)

    def test_parallel_encoding(self):
        self._run(self.backend, test_parallel_encoding)

    def test_image_qa(self):
        self._run(self.chat_vision_backend, test_image_qa)

    def test_stream(self):
        self._run(self.backend, test_stream)
if __name__ == "__main__":
    unittest.main(warnings="ignore")

# Manual debugging snippet:
# from sglang.global_config import global_config
# global_config.verbosity = 2
# t = TestOpenAIBackend()
# t.setUp()
# t.test_decode_json()
||||
74
test/lang/test_srt_backend.py
Normal file
74
test/lang/test_srt_backend.py
Normal file
@@ -0,0 +1,74 @@
|
||||
"""
|
||||
python3 -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000
|
||||
"""
|
||||
import json
|
||||
import unittest
|
||||
|
||||
from sglang.test.test_programs import (
|
||||
test_decode_int,
|
||||
test_decode_json,
|
||||
test_expert_answer,
|
||||
test_few_shot_qa,
|
||||
test_mt_bench,
|
||||
test_parallel_decoding,
|
||||
test_parallel_encoding,
|
||||
test_react,
|
||||
test_regex,
|
||||
test_select,
|
||||
test_stream,
|
||||
test_tool_use,
|
||||
)
|
||||
|
||||
import sglang as sgl
|
||||
|
||||
|
||||
class TestSRTBackend(unittest.TestCase):
    """Run the shared sglang test programs against a local SRT runtime endpoint."""

    backend = None

    def setUp(self):
        klass = type(self)
        if klass.backend is None:
            klass.backend = sgl.RuntimeEndpoint(base_url="http://localhost:30000")
            sgl.set_default_backend(klass.backend)

    def test_few_shot_qa(self):
        test_few_shot_qa()

    def test_mt_bench(self):
        test_mt_bench()

    def test_select(self):
        # The local model's answer is not verified for exactness.
        test_select(check_answer=False)

    def test_decode_int(self):
        test_decode_int()

    def test_expert_answer(self):
        test_expert_answer()

    def test_tool_use(self):
        test_tool_use()

    def test_parallel_decoding(self):
        test_parallel_decoding()

    def test_stream(self):
        test_stream()

    def test_regex(self):
        test_regex()

    # def test_parallel_encoding(self):
    #     test_parallel_encoding(check_answer=False)
if __name__ == "__main__":
    unittest.main(warnings="ignore")

# Manual debugging snippet:
# from sglang.global_config import global_config
# global_config.verbosity = 2
# t = TestSRTBackend()
# t.setUp()
# t.test_regex()
||||
132
test/lang/test_tracing.py
Normal file
132
test/lang/test_tracing.py
Normal file
@@ -0,0 +1,132 @@
|
||||
import unittest
|
||||
|
||||
from sglang.backend.base_backend import BaseBackend
|
||||
from sglang.lang.chat_template import get_chat_template
|
||||
|
||||
import sglang as sgl
|
||||
|
||||
|
||||
class TestTracing(unittest.TestCase):
    """Exercise sglang's tracing and compilation paths on small programs."""

    def test_few_shot_qa(self):
        @sgl.function
        def few_shot_qa(s, question):
            s += "The following are questions with answers.\n\n"
            s += "Q: What is the capital of France?\n"
            s += "A: Paris\n"
            s += "Q: " + question + "\n"
            s += "A:" + sgl.gen("answer", stop="\n")

        graph = few_shot_qa.trace()
        print(graph.last_node.print_graph_dfs() + "\n")

    def test_select(self):
        @sgl.function
        def capital(s):
            s += "The capital of France is"
            s += sgl.select("capital", ["Paris. ", "London. "])
            s += "It is a city" + sgl.gen("description", stop=".")

        graph = capital.trace()
        print(graph.last_node.print_graph_dfs() + "\n")

    def test_raise_warning(self):
        @sgl.function
        def wrong(s, question):
            s += f"I want to ask {question}"

        # Tracing a program that interpolates an untraceable argument must raise.
        try:
            wrong.trace()
            raised = False
        except TypeError:
            raised = True

        assert raised

    def test_multi_function(self):
        @sgl.function
        def expand(s, tip):
            s += (
                "Please expand the following tip into a detailed paragraph:"
                + tip
                + "\n"
            )
            s += sgl.gen("detailed_tip")

        @sgl.function
        def tip_suggestion(s, topic):
            s += "Here are 2 tips for " + topic + ".\n"

            s += "1." + sgl.gen("tip_1", stop=["\n", ":", "."]) + "\n"
            s += "2." + sgl.gen("tip_2", stop=["\n", ":", "."]) + "\n"

            # Each generated tip is expanded by a nested sgl function call.
            branch1 = expand(tip=s["tip_1"])
            branch2 = expand(tip=s["tip_2"])

            s += "Tip 1: " + branch1["detailed_tip"] + "\n"
            s += "Tip 2: " + branch2["detailed_tip"] + "\n"
            s += "In summary" + sgl.gen("summary")

        program = tip_suggestion.compile()
        program.print_graph()

        sgl.set_default_backend(sgl.OpenAI("gpt-3.5-turbo-instruct"))
        state = program.run(topic="staying healthy")
        print(state.text() + "\n")

        batch_states = program.run_batch(
            [
                {"topic": "staying healthy"},
                {"topic": "staying happy"},
                {"topic": "earning money"},
            ],
            temperature=0,
        )
        for st in batch_states:
            print(st.text() + "\n")

    def test_role(self):
        @sgl.function
        def multi_turn_chat(s):
            s += sgl.user("Who are you?")
            s += sgl.assistant(sgl.gen("answer_1"))
            s += sgl.user("Who created you?")
            s += sgl.assistant(sgl.gen("answer_2"))

        backend = BaseBackend()
        backend.chat_template = get_chat_template("llama-2-chat")

        program = multi_turn_chat.compile(backend=backend)
        program.print_graph()

    def test_fork(self):
        @sgl.function
        def tip_suggestion(s):
            s += (
                "Here are three tips for staying healthy: "
                "1. Balanced Diet; "
                "2. Regular Exercise; "
                "3. Adequate Sleep\n"
            )

            # Expand each tip in a separate fork of the state.
            forks = s.fork(3)
            for i in range(3):
                forks[i] += f"Now, expand tip {i+1} into a paragraph:\n"
                forks[i] += sgl.gen(f"detailed_tip")

            s += "Tip 1:" + forks[0]["detailed_tip"] + "\n"
            s += "Tip 2:" + forks[1]["detailed_tip"] + "\n"
            s += "Tip 3:" + forks[2]["detailed_tip"] + "\n"
            s += "In summary" + sgl.gen("summary")

        graph = tip_suggestion.trace()
        print(graph.last_node.print_graph_dfs())

        result = tip_suggestion.run(backend=sgl.OpenAI("gpt-3.5-turbo-instruct"))
        print(result.text())
if __name__ == "__main__":
    unittest.main(warnings="ignore")

# Manual debugging snippet:
# t = TestTracing()
# t.test_fork()
||||
274
test/srt/model/bench_llama_low_api.py
Normal file
274
test/srt/model/bench_llama_low_api.py
Normal file
@@ -0,0 +1,274 @@
|
||||
import multiprocessing as mp
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from sglang.srt.managers.router.model_runner import ModelRunner
|
||||
from sglang.srt.model_config import ModelConfig
|
||||
|
||||
|
||||
@dataclass
class BenchBatch:
    """Mutable holder for the tensors describing one benchmark batch.

    NOTE: the hand-written ``__init__`` replaces the dataclass-generated one,
    so the field declarations below serve only as documentation/defaults.
    """

    req_to_token_pool: torch.Tensor
    token_to_kv_pool: torch.Tensor

    input_ids: torch.Tensor = None
    position_ids_offsets: torch.Tensor = None
    seq_lens: torch.Tensor = None
    prefix_lens: torch.Tensor = None
    req_pool_indices: torch.Tensor = None
    out_cache_loc: torch.Tensor = None
    out_cache_cont_start: torch.Tensor = None
    out_cache_cont_end: torch.Tensor = None

    def __init__(self, model_runner: ModelRunner):
        self.req_to_token_pool = model_runner.req_to_token_pool
        self.token_to_kv_pool = model_runner.token_to_kv_pool

    def init_prefill_batch(self, input_ids, batch_size, seq_len):
        """Allocate request slots and KV-cache space for a fresh prefill batch."""
        self.input_ids = input_ids
        self.position_ids_offsets = torch.zeros(
            batch_size, dtype=torch.int32, device="cuda"
        )
        self.seq_lens = torch.full(
            (batch_size,), seq_len, dtype=torch.int32, device="cuda"
        )
        self.prefix_lens = torch.zeros(batch_size, dtype=torch.int32, device="cuda")
        self.req_pool_indices = self.req_to_token_pool.alloc(batch_size)
        self.out_cache_loc = self.token_to_kv_pool.alloc(batch_size * seq_len)

        # Point each request's token positions at its slice of the KV cache.
        for req in range(batch_size):
            slot = self.req_pool_indices[req].item()
            self.req_to_token_pool.req_to_token[slot, :seq_len] = self.out_cache_loc[
                req * seq_len : (req + 1) * seq_len
            ]

    def update_extend(
        self, input_ids, batch_size, prefix_len, extend_len, prefix_req_idx
    ):
        """Fork existing prefixes and extend each fork by ``extend_len`` tokens."""
        self.input_ids = input_ids
        self.position_ids_offsets = torch.zeros(
            batch_size, dtype=torch.int32, device="cuda"
        )
        self.seq_lens = torch.full(
            (batch_size,), prefix_len + extend_len, dtype=torch.int32, device="cuda"
        )
        self.prefix_lens = torch.full(
            (batch_size,), prefix_len, dtype=torch.int32, device="cuda"
        )
        self.req_pool_indices = self.req_to_token_pool.alloc(batch_size)
        self.out_cache_loc = self.token_to_kv_pool.alloc(batch_size * extend_len)

        req_to_token = self.req_to_token_pool.req_to_token
        # Every prefix request is forked into batch_size / len(prefix_req_idx) children.
        fork_num = batch_size // prefix_req_idx.shape[0]
        for req in range(batch_size):
            parent = prefix_req_idx[req // fork_num].item()
            slot = self.req_pool_indices[req].item()
            # Share the parent's prefix mapping, then append the new tokens.
            req_to_token[slot, :prefix_len] = req_to_token[parent, :prefix_len]
            req_to_token[
                slot, prefix_len : prefix_len + extend_len
            ] = self.out_cache_loc[req * extend_len : (req + 1) * extend_len]

    def update_decode(self, predict_ids, batch_size):
        """Advance every sequence by one decoded token."""
        assert predict_ids.shape[0] == batch_size
        assert batch_size == self.req_pool_indices.shape[0]

        self.input_ids = predict_ids.reshape(-1)
        self.prefix_lens = None
        (
            self.out_cache_loc,
            self.out_cache_cont_start,
            self.out_cache_cont_end,
        ) = self.token_to_kv_pool.alloc_contiguous(batch_size)
        self.req_to_token_pool.req_to_token[
            self.req_pool_indices, self.seq_lens
        ] = self.out_cache_loc
        self.seq_lens.add_(1)
||||
def _greedy_sample(logits):
    """Greedily pick next-token ids from logits.

    Returns a numpy int array of shape (batch, 1). Shared by prefill/extend/
    decode, which previously triplicated this softmax/argmax/cpu sequence.
    """
    prob_out = torch.softmax(logits, dim=-1)
    predict_ids = torch.argmax(prob_out, dim=1, keepdim=True)
    return predict_ids.detach().cpu().numpy()


def prefill(model_runner: ModelRunner, batch: BenchBatch):
    """Run the initial prefill forward pass and greedily sample next tokens."""
    logits, _ = model_runner.forward_extend(
        batch.input_ids,
        batch.req_pool_indices,
        batch.seq_lens,
        batch.prefix_lens,
        batch.position_ids_offsets,
        batch.out_cache_loc,
        False,  # NOTE(review): flag meaning not visible here — confirm in ModelRunner
    )
    return _greedy_sample(logits)


def extend(model_runner: ModelRunner, batch: BenchBatch):
    """Run an extend forward pass (prefix reuse) and greedily sample next tokens."""
    logits, _ = model_runner.forward_extend(
        batch.input_ids,
        batch.req_pool_indices,
        batch.seq_lens,
        batch.prefix_lens,
        batch.position_ids_offsets,
        batch.out_cache_loc,
        True,  # NOTE(review): flag meaning not visible here — confirm in ModelRunner
    )
    return _greedy_sample(logits)


def decode(model_runner: ModelRunner, batch: BenchBatch):
    """Run a single-token decode step and greedily sample next tokens."""
    logits = model_runner.forward_decode(
        batch.input_ids,
        batch.req_pool_indices,
        batch.seq_lens,
        None,
        batch.position_ids_offsets,
        None,
        batch.out_cache_cont_start,
        batch.out_cache_cont_end,
    )
    return _greedy_sample(logits)
||||
def bench_generate_worker(
    model_path,
    tp_rank,
    tp_size,
    shared_num,
    unique_num,
    shared_len,
    unique_len,
    decode_len,
    model_mode,
):
    """One tensor-parallel worker: prefill shared prefixes, fork/extend, decode.

    Timings for each phase are printed from rank 0 only.
    """
    assert unique_num % shared_num == 0

    model_config = ModelConfig(path=model_path)
    model_runner = ModelRunner(
        model_config=model_config,
        mem_fraction_static=0.8,
        tp_rank=tp_rank,
        tp_size=tp_size,
        nccl_port=28888,
        model_mode=model_mode,
    )

    batch = BenchBatch(model_runner)

    # Warm-up pass; results are discarded and the pools are cleared afterwards.
    for _ in range(1):
        warm_ids = torch.randint(
            low=5, high=100, size=(shared_num * shared_len,)
        ).cuda()
        batch.init_prefill_batch(warm_ids, shared_num, shared_len)
        _ = prefill(model_runner, batch)

        warm_ids = torch.randint(
            low=5, high=100, size=(unique_num * unique_len,)
        ).cuda()
        batch.update_extend(
            warm_ids, unique_num, shared_len, unique_len, batch.req_pool_indices
        )
        predict_ids = extend(model_runner, batch)

        for _ in range(decode_len):
            predict_ids = torch.from_numpy(predict_ids).cuda()
            batch.update_decode(predict_ids, unique_num)
            predict_ids = decode(model_runner, batch)

    model_runner.req_to_token_pool.clear()
    model_runner.token_to_kv_pool.clear()

    if tp_size > 1:
        dist.barrier()

    # Timed phase 1: prefill the shared prefixes.
    prefill_start = time.time()
    input_ids = torch.randint(low=5, high=100, size=(shared_num * shared_len,)).cuda()
    batch.init_prefill_batch(input_ids, shared_num, shared_len)
    _ = prefill(model_runner, batch)
    if tp_rank == 0:
        print(f"prefill: {(time.time() - prefill_start) * 1000:.2f} ms")

    # Timed phase 2: fork the prefixes and extend with unique suffixes.
    extend_start = time.time()
    input_ids = torch.randint(low=5, high=100, size=(unique_num * unique_len,)).cuda()
    batch.update_extend(
        input_ids, unique_num, shared_len, unique_len, batch.req_pool_indices
    )
    predict_ids = extend(model_runner, batch)
    if tp_rank == 0:
        print(f"extend: {(time.time() - extend_start) * 1000:.2f} ms")

    # Timed phase 3: per-step decode.
    for i in range(decode_len):
        decode_start = time.time()
        predict_ids = torch.from_numpy(predict_ids).cuda()
        batch.update_decode(predict_ids, unique_num)
        predict_ids = decode(model_runner, batch)
        if tp_rank == 0:
            print(f"decode {i}: {(time.time() - decode_start) * 1000:.2f} ms")
||||
def bench_generate(
    model_path,
    tp_size,
    shared_num,
    unique_num,
    shared_len,
    unique_len,
    decode_len,
    model_mode,
):
    """Spawn one benchmark worker per tensor-parallel rank and wait for all."""
    print(
        f"tp_size: {tp_size}, "
        f"shared_num: {shared_num}, "
        f"unique_num: {unique_num}, "
        f"shared_len: {shared_len}, "
        f"unique_len: {unique_len}, "
        f"decode_len: {decode_len}, "
        f"model_mode: {model_mode}"
    )

    procs = []
    for tp_rank in range(tp_size):
        proc = mp.Process(
            target=bench_generate_worker,
            args=(
                model_path,
                tp_rank,
                tp_size,
                shared_num,
                unique_num,
                shared_len,
                unique_len,
                decode_len,
                model_mode,
            ),
        )
        proc.start()
        procs.append(proc)

    for proc in procs:
        proc.join()
|
||||
if __name__ == "__main__":
    # Benchmark: 1 shared prefix forked into 32 unique continuations.
    bench_generate(
        model_path="meta-llama/Llama-2-7b-chat-hf",
        tp_size=1,
        shared_num=1,
        unique_num=32,
        shared_len=256,
        unique_len=256,
        decode_len=8,
        model_mode=[],
    )
||||
80
test/srt/model/reference_hf.py
Normal file
80
test/srt/model/reference_hf.py
Normal file
@@ -0,0 +1,80 @@
|
||||
import argparse
|
||||
import os
|
||||
|
||||
import torch
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
|
||||
@torch.inference_mode()
def normal_text(args):
    """Greedy-generate continuations for a few text prompts and print the
    last-token prefill logits as a reference for other runners."""
    tokenizer = AutoTokenizer.from_pretrained(args.model_path)
    model = AutoModelForCausalLM.from_pretrained(
        args.model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True
    )
    model.cuda()

    print(model)

    prompts = [
        "The capital of France is",
        "The capital of the United Kindom is",
        "Today is a sunny day and I like",
    ]
    max_new_tokens = 32

    for prompt in prompts:
        # Prompts may be raw strings or pre-tokenized id lists.
        if isinstance(prompt, str):
            input_ids = tokenizer.encode(prompt, return_tensors="pt").cuda()
        else:
            input_ids = torch.tensor([prompt], device="cuda")

        output_ids = model.generate(
            input_ids, do_sample=False, max_new_tokens=max_new_tokens
        )
        output_str = tokenizer.decode(output_ids[0])
        print(output_str)

        # Logits of the final prompt token, for cross-implementation comparison.
        prefill_logits = model.forward(input_ids).logits[0][-1]
        print("prefill logits", prefill_logits)
|
||||
@torch.inference_mode()
def synthetic_tokens(args):
    """Greedy-decode from a synthetic token ramp, printing per-step logits."""
    t = AutoTokenizer.from_pretrained(args.model_path)
    m = AutoModelForCausalLM.from_pretrained(
        args.model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True
    )
    m.cuda()
    print(m)

    input_len = 256
    output_len = 8
    prompts = [list(range(5, 5 + input_len))]

    for p in prompts:
        input_ids = p
        for step in range(output_len + 1):
            # Re-run the full (growing) sequence each step; step 0 is the prefill.
            logits = m.forward(torch.tensor([input_ids], device="cuda")).logits[
                0
            ][-1]

            if step == 0:
                print("prefill logits", logits)
            else:
                print("decode", step - 1, logits)

            # Greedy decoding: append the argmax token for the next iteration.
            input_ids.append(torch.argmax(logits).item())
|
||||
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model-path",
        type=str,
        default="TinyLlama/TinyLlama-1.1B-Chat-v0.4",
        # default="meta-llama/Llama-2-7b-chat-hf",
    )
    cli_args = parser.parse_args()

    normal_text(cli_args)
    # synthetic_tokens(cli_args)
|
||||
108
test/srt/model/test_llama_extend.py
Normal file
108
test/srt/model/test_llama_extend.py
Normal file
@@ -0,0 +1,108 @@
|
||||
import multiprocessing
|
||||
import os
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import transformers
|
||||
from sglang.srt.managers.router.infer_batch import Batch, ForwardMode, Req
|
||||
from sglang.srt.managers.router.model_runner import ModelRunner
|
||||
from sglang.srt.model_config import ModelConfig
|
||||
from sglang.srt.sampling_params import SamplingParams
|
||||
|
||||
|
||||
def test_generate_worker(model_path, tp_rank, tp_size):
    """Prefill truncated prompts, extend with the remainder, then decode steps."""
    model_config = ModelConfig(path=model_path)
    model = ModelRunner(model_config, 0.8, tp_rank, tp_size, 28888)
    tokenizer = transformers.AutoTokenizer.from_pretrained(model_path)

    # Input
    prompts = [
        "The capital of France is",
        "Today is a sunny day and I like",
    ]
    sampling_params = SamplingParams(temperature=0)

    cut_num = 4  # number of prompt tokens used for the initial prefill

    reqs = []
    for idx in range(len(prompts)):
        req = Req(idx)
        req.input_ids = tokenizer.encode(prompts[idx])[:cut_num]
        req.sampling_params = sampling_params
        reqs.append(req)

    # Prefill the first cut_num tokens of each prompt.
    batch = Batch(reqs, model.req_to_token_pool, model.token_to_kv_pool, None)
    batch.init_extend_batch(model.model_config.vocab_size(), None)
    logits, _ = model.forward(batch, ForwardMode.EXTEND)
    next_token_ids, next_token_probs = batch.sample(logits)
    print("extend logits (first)", logits)

    # Extend with the remaining prompt tokens, reusing the cached prefix.
    for idx in range(len(prompts)):
        req = reqs[idx]
        req.input_ids += tokenizer.encode(prompts[idx])[cut_num:]
        req.prefix_indices = model.req_to_token_pool.req_to_token[
            batch.req_pool_indices[idx], :cut_num
        ]
    batch = Batch(reqs, model.req_to_token_pool, model.token_to_kv_pool, None)
    batch.init_extend_batch(model.model_config.vocab_size(), None)
    logits, _ = model.forward(batch, ForwardMode.EXTEND)
    next_token_ids, next_token_probs = batch.sample(logits)

    print("extend logits", logits)
    print(
        "next_token_ids", next_token_ids, [tokenizer.decode(x) for x in next_token_ids]
    )

    # Decode
    for _ in range(6):
        batch.update_for_decode(next_token_ids.cpu().numpy())
        # NOTE(review): EXTEND mode unpacks a (logits, _) pair above, but DECODE
        # is used un-unpacked here — confirm model.forward's DECODE return shape.
        logits = model.forward(batch, ForwardMode.DECODE)
        next_token_ids, next_token_probs = batch.sample(logits)

        print(
            "next_token_ids",
            next_token_ids,
            [tokenizer.decode(x) for x in next_token_ids],
        )
||||
def test_generate(model_path, tp_size):
    """Spawn one worker per tensor-parallel rank and wait for all of them."""
    procs = []
    for tp_rank in range(tp_size):
        proc = multiprocessing.Process(
            target=test_generate_worker,
            args=(
                model_path,
                tp_rank,
                tp_size,
            ),
        )
        proc.start()
        procs.append(proc)

    for proc in procs:
        proc.join()
|
||||
|
||||
if __name__ == "__main__":
    # Avoid the tokenizers fork-safety warning in the spawned workers.
    os.environ["TOKENIZERS_PARALLELISM"] = "false"
    test_generate("TinyLlama/TinyLlama-1.1B-Chat-v0.4", 1)

# Reference output for TinyLlama-1.1B-Chat-v0.4
# extend logits (first) tensor([[-10.0312, -9.5000, 0.8896, ..., -4.9375, -3.2402, -3.3633],
#  [ -9.1797, -10.2500, 2.7168, ..., -4.3359, -4.0664, -4.1289]],
#  device='cuda:0', dtype=torch.float16)
# extend logits tensor([[-8.3125, -7.1172, 3.3359, ..., -4.9531, -4.1289, -3.4121],
#  [-9.6406, -9.0547, 4.0195, ..., -5.3086, -4.7188, -4.4609]],
#  device='cuda:0', dtype=torch.float16)
# next_token_ids tensor([3681, 304], device='cuda:0') ['Paris', 'to']
# next_token_ids tensor([29889, 748], device='cuda:0') ['.', 'go']
# next_token_ids tensor([ 13, 363], device='cuda:0') ['\n', 'for']
# next_token_ids tensor([1576, 263], device='cuda:0') ['The', 'a']
# next_token_ids tensor([7483, 6686], device='cuda:0') ['capital', 'walk']
# next_token_ids tensor([310, 297], device='cuda:0') ['of', 'in']
# next_token_ids tensor([278, 278], device='cuda:0') ['the', 'the']
||||
209
test/srt/model/test_llama_low_api.py
Normal file
209
test/srt/model/test_llama_low_api.py
Normal file
@@ -0,0 +1,209 @@
|
||||
import multiprocessing
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from sglang.srt.managers.router.model_runner import ModelRunner
|
||||
from sglang.srt.model_config import ModelConfig
|
||||
|
||||
|
||||
def test_generate_worker(
    model_path, tp_rank, tp_size, batch_size, input_len, output_len
):
    """Load a model on one tensor-parallel rank and benchmark prefill/decode.

    Runs one warm-up pass (printing logits for inspection against the
    reference output below), frees the warm-up batch, then runs a timed
    pass and reports prefill, per-step decode, and total latency.
    """
    model_config = ModelConfig(path=model_path)
    # 0.8 = fraction of GPU memory given to the KV cache; 28888 = dist port.
    model = ModelRunner(model_config, 0.8, tp_rank, tp_size, 28888)

    # Prepare data: every request in the batch shares the same synthetic
    # token ids [5, input_len + 5).
    input_ids = np.vstack([np.arange(5, input_len + 5) for _ in range(batch_size)])
    input_ids = input_ids.reshape(-1)
    input_ids = torch.tensor(input_ids).cuda()

    def init_batch_data(model, batch_size, input_len):
        """Allocate request slots and KV-cache entries for a fresh batch."""
        req_pool_indices = model.req_to_token_pool.alloc(batch_size)
        seq_lens = torch.full(
            (batch_size,), input_len, dtype=torch.int32, device="cuda"
        )
        prefix_lens = torch.zeros(batch_size, dtype=torch.int32, device="cuda")
        position_ids_offsets = torch.zeros(batch_size, dtype=torch.int32, device="cuda")

        out_cache_loc = model.token_to_kv_pool.alloc(batch_size * input_len)
        # Map each request's token positions to its slice of the KV cache.
        for i in range(batch_size):
            req_idx = req_pool_indices[i].item()
            model.req_to_token_pool.req_to_token[req_idx, :input_len] = out_cache_loc[
                i * input_len : (i + 1) * input_len
            ]

        return (
            req_pool_indices,
            seq_lens,
            prefix_lens,
            position_ids_offsets,
            out_cache_loc,
        )

    def prefill(print_logits):
        """Run one prefill step and store the greedy next tokens in predict_ids."""
        nonlocal predict_ids

        logits, _ = model.forward_prefill(
            input_ids,
            req_pool_indices,
            seq_lens,
            prefix_lens,
            position_ids_offsets,
            out_cache_loc,
            False,
        )
        prob_out = torch.softmax(logits, dim=-1)
        predict_ids = torch.argmax(prob_out, dim=1, keepdim=True)
        predict_ids = predict_ids.detach().cpu().numpy()

        if print_logits and tp_rank == 0:
            print("prefill logits", logits, logits.shape)

    def decode(step, print_logits):
        """Run one decode step, feeding back the previous predictions.

        ``step`` is passed explicitly (instead of closing over the caller's
        loop variable) so the printed step index is always correct.
        """
        nonlocal predict_ids

        (
            out_cache_loc,
            out_cache_cont_start,
            out_cache_cont_end,
        ) = model.token_to_kv_pool.alloc_contiguous(batch_size)
        model.req_to_token_pool.req_to_token[req_pool_indices, seq_lens] = out_cache_loc
        seq_lens.add_(1)
        logits = model.forward_decode(
            torch.from_numpy(predict_ids).cuda().reshape(-1),
            req_pool_indices,
            seq_lens,
            None,
            position_ids_offsets,
            None,
            out_cache_cont_start,
            out_cache_cont_end,
        )
        prob_out = torch.softmax(logits, dim=-1)
        predict_ids = torch.argmax(prob_out, dim=1, keepdim=True)
        predict_ids = predict_ids.detach().cpu().numpy()
        if print_logits and tp_rank == 0:
            print("decode", step, logits)

    # Warm up
    (
        req_pool_indices,
        seq_lens,
        prefix_lens,
        position_ids_offsets,
        out_cache_loc,
    ) = init_batch_data(model, batch_size, input_len)
    predict_ids = None

    prefill(True)
    for i in range(output_len):
        decode(i, True)

    # Free the warm-up batch so the benchmark starts from a clean pool.
    for i in range(batch_size):
        req_idx = req_pool_indices[i].item()
        model.token_to_kv_pool.free(
            model.req_to_token_pool.req_to_token[req_idx, : seq_lens[i]]
        )
    model.req_to_token_pool.free(req_pool_indices)

    # Benchmark
    if tp_size > 1:
        dist.barrier()
    start_time = prefill_start_time = time.time()

    (
        req_pool_indices,
        seq_lens,
        prefix_lens,
        position_ids_offsets,
        out_cache_loc,
    ) = init_batch_data(model, batch_size, input_len)

    prefill(False)

    if tp_rank == 0:
        print(f"prefill cost: {(time.time() - prefill_start_time) * 1000:.2f} ms")

    for i in range(output_len):
        step_start = time.time()

        decode(i, False)

        step_end = time.time()

        if i % 100 == 0 or i == output_len - 1:
            if tp_rank == 0:
                print(f"step {i} cost: {(step_end - step_start) * 1000:.2f} ms")

    end_time = time.time()

    if tp_rank == 0:
        # BUGFIX: label the total with its unit, consistent with the
        # "prefill cost" and "step cost" prints above.
        print(f"total cost: {(end_time - start_time) * 1000:.2f} ms")
|
||||
|
||||
|
||||
def test_generate(model_path, tp_size, batch_size, input_len, output_len):
    """Spawn one benchmark worker per tensor-parallel rank and wait for all."""
    procs = []
    for rank in range(tp_size):
        worker = multiprocessing.Process(
            target=test_generate_worker,
            args=(model_path, rank, tp_size, batch_size, input_len, output_len),
        )
        worker.start()
        procs.append(worker)

    for worker in procs:
        worker.join()
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Small smoke-test configuration; reference logits are listed below.
    test_generate("TinyLlama/TinyLlama-1.1B-Chat-v0.4", 1, 1, 256, 8)
    # test_generate("meta-llama/Llama-2-7b-chat-hf", 1, 16, 256, 8)
|
||||
|
||||
# Reference output for TinyLlama-1.1B-Chat-v0.4 (1, 32, 8)
|
||||
# prefill logits tensor([[-1.3380e-03, 4.4702e-01, 2.9082e+00, ..., -1.8398e+00,
|
||||
# 1.8281e+00, 2.1816e+00]], device='cuda:0')
|
||||
# decode 0 tensor([[-0.3904, 0.8784, 3.6934, ..., -2.4473, 1.5811, 2.0098]],
|
||||
# device='cuda:0')
|
||||
# decode 1 tensor([[-0.3552, 0.0635, 2.5781, ..., -2.5820, 1.3047, 1.7607]],
|
||||
# device='cuda:0')
|
||||
# decode 2 tensor([[-1.5645, -1.1963, 3.8145, ..., -2.9766, 1.0244, 1.0645]],
|
||||
# device='cuda:0')
|
||||
# decode 3 tensor([[-1.3682, -0.6548, 4.2734, ..., -2.8711, 1.1172, 1.1494]],
|
||||
# device='cuda:0')
|
||||
# decode 4 tensor([[-1.0205, -0.0060, 4.4844, ..., -2.7090, 1.6143, 1.8135]],
|
||||
# device='cuda:0')
|
||||
# decode 5 tensor([[ 0.4260, 1.6006, 4.3633, ..., -2.2480, 2.5547, 2.8379]],
|
||||
# device='cuda:0')
|
||||
# decode 6 tensor([[ 0.7095, 2.1816, 5.0078, ..., -2.1309, 3.0293, 3.0840]],
|
||||
# device='cuda:0')
|
||||
# decode 7 tensor([[-0.2883, 1.1289, 4.7188, ..., -2.4023, 2.1055, 2.1836]],
|
||||
# device='cuda:0')
|
||||
|
||||
# Reference output for TinyLlama-1.1B-Chat-v0.4 (1, 256, 8)
|
||||
# prefill logits tensor([[-2.5840, -2.7227, 6.8047, ..., -2.3613, 0.1224, 0.5952]],
|
||||
# device='cuda:0')
|
||||
# decode 0 tensor([[-0.6235, -0.7690, 9.2891, ..., -1.4922, 2.8008, 2.9531]],
|
||||
# device='cuda:0')
|
||||
# decode 1 tensor([[-1.3662, -1.4648, 7.1250, ..., -1.7861, 1.7363, 1.8857]],
|
||||
# device='cuda:0')
|
||||
# decode 2 tensor([[-0.8540, -0.5947, 9.1328, ..., -2.1211, 2.9707, 2.8945]],
|
||||
# device='cuda:0')
|
||||
# decode 3 tensor([[ 0.0652, 1.0312, 8.1250, ..., -2.0586, 3.4727, 3.6172]],
|
||||
# device='cuda:0')
|
||||
# decode 4 tensor([[-0.0459, 1.0098, 9.1406, ..., -2.1797, 3.8320, 3.9355]],
|
||||
# device='cuda:0')
|
||||
# decode 5 tensor([[ 0.2964, 1.3564, 9.8828, ..., -2.1602, 4.1836, 4.2422]],
|
||||
# device='cuda:0')
|
||||
# decode 6 tensor([[ 0.6475, 1.8105, 10.1250, ..., -2.0098, 4.2578, 4.4062]],
|
||||
# device='cuda:0')
|
||||
# decode 7 tensor([[ 0.4985, 1.4746, 9.9062, ..., -1.9141, 3.9863, 4.3047]],
|
||||
# device='cuda:0')
|
||||
161
test/srt/model/test_llava_low_api.py
Normal file
161
test/srt/model/test_llava_low_api.py
Normal file
@@ -0,0 +1,161 @@
|
||||
import multiprocessing
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from sglang.srt.hf_transformers_utils import get_processor
|
||||
from sglang.srt.managers.router.infer_batch import ForwardMode
|
||||
from sglang.srt.managers.router.model_runner import InputMetadata, ModelRunner
|
||||
from sglang.srt.model_config import ModelConfig
|
||||
from sglang.srt.utils import load_image
|
||||
|
||||
|
||||
def init_batch_data(model, batch_size, input_len):
    """Allocate request slots and KV-cache entries for a fresh batch.

    Returns ``(req_pool_indices, seq_lens, prefix_lens, position_ids_offsets,
    out_cache_loc)`` — CUDA tensors sized for ``batch_size`` requests of
    ``input_len`` tokens each.
    """
    req_pool_indices = model.req_to_token_pool.alloc(batch_size)
    seq_lens = torch.full((batch_size,), input_len, dtype=torch.int32, device="cuda")
    prefix_lens = torch.zeros(batch_size, dtype=torch.int32, device="cuda")
    position_ids_offsets = torch.zeros(batch_size, dtype=torch.int32, device="cuda")

    out_cache_loc = model.token_to_kv_pool.alloc(batch_size * input_len)
    for i in range(batch_size):
        # BUGFIX: index the token table by the allocated request slot, not by
        # the loop counter — the two only coincide when the pool is empty.
        # (Matches the correct version in test_llama_low_api.py.)
        req_idx = req_pool_indices[i].item()
        model.req_to_token_pool.req_to_token[req_idx, :input_len] = out_cache_loc[
            i * input_len : (i + 1) * input_len
        ]

    return (
        req_pool_indices,
        seq_lens,
        prefix_lens,
        position_ids_offsets,
        out_cache_loc,
    )
|
||||
|
||||
|
||||
def prefill(model, tp_rank, params, print_logits):
    """Run the multi-modal extend (prefill) pass and return greedy token ids."""
    logits, _ = model.forward_extend_multi_modal(
        *params,
        False,
    )
    next_ids = torch.argmax(torch.softmax(logits, dim=-1), dim=1, keepdim=True)
    next_ids = next_ids.detach().cpu().numpy()

    if print_logits and tp_rank == 0:
        print("prefill logits", logits, logits.shape)

    return next_ids
|
||||
|
||||
|
||||
def decode(step, model, tp_rank, batch_size, predict_ids, params, print_logits):
    """Run one decode step, feeding back ``predict_ids``; return new greedy ids."""
    req_pool_indices, seq_lens, _prefix_lens, position_ids_offsets, _ = params

    (
        out_cache_loc,
        out_cache_cont_start,
        out_cache_cont_end,
    ) = model.token_to_kv_pool.alloc_contiguous(batch_size)
    # Record where the newly generated token's KV entry lives, then grow
    # every sequence by one.
    model.req_to_token_pool.req_to_token[req_pool_indices, seq_lens] = out_cache_loc
    seq_lens.add_(1)

    logits = model.forward_decode(
        torch.from_numpy(predict_ids).cuda().reshape(-1),
        req_pool_indices,
        seq_lens,
        None,
        position_ids_offsets,
        None,
        out_cache_cont_start,
        out_cache_cont_end,
    )
    next_ids = torch.argmax(torch.softmax(logits, dim=-1), dim=1, keepdim=True)
    next_ids = next_ids.detach().cpu().numpy()

    if print_logits and tp_rank == 0:
        print("decode", step, logits)
    return next_ids
|
||||
|
||||
|
||||
def test_generate_worker(
    model_path,
    tp_rank,
    tp_size,
):
    """Prefill a LLaVA prompt+image, decode 16 tokens, and check the caption."""
    model_config = ModelConfig(path=model_path)
    # 0.8 = fraction of GPU memory given to the KV cache; 28888 = dist port.
    model = ModelRunner(model_config, 0.8, tp_rank, tp_size, 28888)
    # print(model.model)

    # Prepare data
    prompt = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. USER: <image>\nDescribe this picture ASSISTANT:"
    # NOTE(review): hard-coded absolute path; only resolves on the original
    # dev machine — consider making this relative to the test directory.
    image_path = "/home/ubuntu/sglang/test/lang/image.png"
    image = load_image(image_path)

    processor = get_processor("llava-hf/llava-1.5-7b-hf")
    input_ids = processor.tokenizer.encode(prompt)
    pixel_values = processor.image_processor(image)["pixel_values"]
    # Pad the token sequence with image-placeholder slots; the list holds the
    # image hashes/keys (a single dummy 0 here for one image).
    input_ids, offset = model.model.pad_input_ids(
        input_ids,
        [
            0,
        ],
    )

    params = init_batch_data(model, 1, len(input_ids))

    # inference
    output_ids = []
    prefill_params = (
        torch.tensor(np.array(input_ids)).cuda(),
        np.array(pixel_values),
        [offset],
        *params,
    )
    # BUGFIX: pass the worker's actual rank instead of a hard-coded 0 so the
    # rank-0 print gating inside prefill/decode works under tensor parallelism.
    predict_ids = prefill(
        model, tp_rank=tp_rank, params=prefill_params, print_logits=False
    )
    output_ids.append(predict_ids[0][0])
    for i in range(16):
        predict_ids = decode(
            i,
            model,
            tp_rank=tp_rank,
            batch_size=1,
            predict_ids=predict_ids,
            params=params,
            print_logits=False,
        )
        output_ids.append(predict_ids[0][0])

    # detokenization
    output = processor.tokenizer.batch_decode(
        [output_ids], skip_special_tokens=True, clean_up_tokenization_spaces=False
    )[0]
    assert (
        output
        == "The image features a man standing on the back of a yellow taxi cab, holding"
    )
|
||||
|
||||
|
||||
def test_generate(model_path, tp_size):
    """Spawn one LLaVA worker per tensor-parallel rank and wait for all."""
    procs = []
    for rank in range(tp_size):
        worker = multiprocessing.Process(
            target=test_generate_worker,
            args=(model_path, rank, tp_size),
        )
        worker.start()
        procs.append(worker)

    for worker in procs:
        worker.join()
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Single-GPU smoke test of the low-level LLaVA path.
    test_generate("liuhaotian/llava-v1.5-7b", 1)
|
||||
163
test/srt/test_flashinfer.py
Normal file
163
test/srt/test_flashinfer.py
Normal file
@@ -0,0 +1,163 @@
|
||||
import flashinfer
|
||||
import pytest
|
||||
import torch
|
||||
from sglang.srt.layers.extend_attention import extend_attention_fwd
|
||||
from sglang.srt.layers.token_attention import token_attention_fwd
|
||||
|
||||
|
||||
@pytest.mark.parametrize("batch_size", [12, 37, 67])
@pytest.mark.parametrize("kv_len", [54, 97])
@pytest.mark.parametrize("qo_len", [37, 17])
@pytest.mark.parametrize("num_kv_heads", [4])
@pytest.mark.parametrize("num_qo_heads", [4, 32])
@pytest.mark.parametrize("head_dim", [128])
@pytest.mark.parametrize("use_wrapper", [True, False])
def test_batch_prefill_with_paged_kv_cache(
    batch_size,
    kv_len,
    qo_len,
    num_kv_heads,
    num_qo_heads,
    head_dim,
    use_wrapper,
):
    """Check the triton extend-attention kernel against flashinfer prefill."""
    # Shared inputs: page size 1, each request owns kv_len consecutive pages.
    q = torch.randn(batch_size * qo_len, num_qo_heads, head_dim).to(0).half()
    q_indptr = torch.arange(0, batch_size + 1).to(0).int() * qo_len
    total_tokens = kv_len * batch_size
    kv_data = torch.randn(total_tokens, 2, num_kv_heads, 1, head_dim).to(0).half()
    kv_indptr = torch.arange(0, batch_size + 1).to(0).int() * kv_len
    kv_indices = torch.arange(0, total_tokens).to(0).int()
    kv_last_page_len = torch.full((batch_size,), 1, dtype=torch.int32).to(0)

    # Triton-kernel arguments: the last qo_len tokens of each request form the
    # "extend" part, while kv_data as a whole is the cache buffer.
    per_request = kv_data.view(batch_size, kv_len, 2, -1)
    k_extend = (
        per_request[:, -qo_len:, 0].contiguous().view(-1, num_kv_heads, head_dim)
    )
    v_extend = (
        per_request[:, -qo_len:, 1].contiguous().view(-1, num_kv_heads, head_dim)
    )
    o_triton = torch.empty_like(q)
    k_buffer = kv_data[:, 0].view(-1, num_kv_heads, head_dim).contiguous()
    v_buffer = kv_data[:, 1].view(-1, num_kv_heads, head_dim).contiguous()
    req_to_token = torch.arange(0, total_tokens).to(0).int().view(batch_size, kv_len)
    b_req_idx = torch.arange(0, batch_size).to(0).int()
    b_seq_len = torch.full((batch_size,), kv_len, dtype=torch.int32).to(0)
    b_start_loc_extend = torch.arange(0, batch_size).to(0).int() * qo_len
    b_seq_len_extend = torch.full((batch_size,), qo_len, dtype=torch.int32).to(0)
    max_len_in_batch = kv_len
    max_len_extend = qo_len

    extend_attention_fwd(
        q,
        k_extend,
        v_extend,
        o_triton,
        k_buffer,
        v_buffer,
        req_to_token,
        b_req_idx,
        None,  # b_start_loc = None
        b_seq_len,
        None,  # b_seq_len_prefix = None
        b_start_loc_extend,
        b_seq_len_extend,
        max_len_in_batch,
        max_len_extend,
    )

    # Reference output via flashinfer, either through the wrapper API or the
    # one-shot function.
    if use_wrapper:
        wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper()
        wrapper.begin_forward(q_indptr, batch_size, num_qo_heads, num_kv_heads)
        o = wrapper.forward(
            q, q_indptr, kv_data, kv_indptr, kv_indices, kv_last_page_len
        )
    else:
        o = flashinfer.batch_prefill_with_paged_kv_cache(
            q,
            q_indptr,
            kv_data,
            kv_indptr,
            kv_indices,
            kv_last_page_len,
        )

    print("Mean: ", torch.mean(torch.abs(o - o_triton)))
    print("Max: ", torch.max(torch.abs(o - o_triton)))
    assert torch.allclose(o, o_triton, rtol=1e-2, atol=1e-3)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("batch_size", [12, 17, 37])
@pytest.mark.parametrize("kv_len", [54, 127, 537])
@pytest.mark.parametrize("num_kv_heads", [32])
@pytest.mark.parametrize("num_qo_heads", [32])
@pytest.mark.parametrize("head_dim", [128])
def test_batch_decode_with_paged_kv_cache(
    batch_size,
    kv_len,
    num_kv_heads,
    num_qo_heads,
    head_dim,
):
    """Check the triton token-attention decode kernel against flashinfer."""
    # note(lsyin): under pytest the number of heads cannot change because the
    # triton kernel keeps a cache; to test a different decode shape, change
    # the parameters in __main__ and run decode only once.

    q = torch.randn(batch_size, num_qo_heads, head_dim).to(0).half()
    total_tokens = kv_len * batch_size
    kv_data = torch.randn(total_tokens, 2, num_kv_heads, 1, head_dim).to(0).half()
    kv_indptr = torch.arange(0, batch_size + 1).to(0).int() * kv_len
    kv_indices = torch.arange(0, total_tokens).to(0).int()
    kv_last_page_len = torch.full((batch_size,), 1, dtype=torch.int32).to(0)

    # Triton-kernel arguments.
    k_buffer = kv_data[:, 0].view(-1, num_kv_heads, head_dim).contiguous()
    v_buffer = kv_data[:, 1].view(-1, num_kv_heads, head_dim).contiguous()
    o_triton = torch.empty_like(q)
    req_to_token = (
        torch.arange(0, total_tokens).to(0).int().view(batch_size, kv_len)
    )
    b_req_idx = torch.arange(0, batch_size).to(0).int()
    b_start_loc = torch.arange(0, batch_size).to(0).int() * kv_len
    b_seq_len = torch.full((batch_size,), kv_len, dtype=torch.int32).to(0)
    max_len_in_batch = kv_len
    other_kv_index = 0

    token_attention_fwd(
        q,
        k_buffer,
        v_buffer,
        o_triton,
        req_to_token,
        b_req_idx,
        b_start_loc,
        b_seq_len,
        max_len_in_batch,
        other_kv_index,
        total_tokens,
    )

    # Reference output via the flashinfer decode wrapper (page size 1,
    # no positional encoding, fp16 data).
    wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper()
    wrapper.begin_forward(
        kv_indptr,
        kv_last_page_len,
        batch_size,
        num_qo_heads,
        num_kv_heads,
        head_dim,
        1,
        "NONE",
        "float16",
    )
    o = wrapper.forward(q, kv_data, kv_indptr, kv_indices, kv_last_page_len)

    print("Mean: ", torch.mean(torch.abs(o - o_triton)))
    print("Max: ", torch.max(torch.abs(o - o_triton)))
    assert torch.allclose(o, o_triton, rtol=1e-2, atol=2e-3)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Run a few shapes directly, avoiding the pytest triton-cache limitation.
    test_batch_prefill_with_paged_kv_cache(12, 54, 37, 8, 8, 128, False)
    test_batch_prefill_with_paged_kv_cache(37, 1111, 456, 32, 32, 128, True)
    test_batch_decode_with_paged_kv_cache(12, 54, 4, 32, 128)
|
||||
56
test/srt/test_httpserver_concurrent.py
Normal file
56
test/srt/test_httpserver_concurrent.py
Normal file
@@ -0,0 +1,56 @@
|
||||
"""
|
||||
python3 -m sglang.launch_server --model-path TinyLlama/TinyLlama-1.1B-Chat-v0.4 --port 30000
|
||||
|
||||
Output:
|
||||
The capital of France is Paris.\nThe capital of the United States is Washington, D.C.
|
||||
|
||||
The capital of the United Kindom is London.\nThe capital of the United Kingdom is London.\nThe capital of
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import json
|
||||
import time
|
||||
|
||||
import aiohttp
|
||||
import requests
|
||||
|
||||
|
||||
async def send_request(url, data, delay=0):
    """POST ``data`` as JSON to ``url`` after ``delay`` seconds; return the parsed reply."""
    await asyncio.sleep(delay)
    async with aiohttp.ClientSession() as session:
        async with session.post(url, json=data) as resp:
            return await resp.json()
|
||||
|
||||
|
||||
async def main(args):
    """Fire two /generate requests concurrently (one delayed) and print both replies."""
    url = f"{args.host}:{args.port}"

    # The first request starts one second late so the second one is already
    # in flight, exercising concurrent handling on the server side.
    task1 = send_request(
        url + "/generate",
        {
            "text": "The capital of France is",
            "sampling_params": {"temperature": 0, "max_new_tokens": 128},
        },
        delay=1,
    )
    task2 = send_request(
        url + "/generate",
        {
            "text": "The capital of the United Kindom is",
            "sampling_params": {"temperature": 0, "max_new_tokens": 128},
        },
    )

    rets = await asyncio.gather(task1, task2)
    print(rets)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="http://127.0.0.1")
    parser.add_argument("--port", type=int, default=30000)
    cli_args = parser.parse_args()

    asyncio.run(main(cli_args))
|
||||
31
test/srt/test_httpserver_decode.py
Normal file
31
test/srt/test_httpserver_decode.py
Normal file
@@ -0,0 +1,31 @@
|
||||
"""
|
||||
python3 -m sglang.launch_server --model-path TinyLlama/TinyLlama-1.1B-Chat-v0.4 --port 30000
|
||||
|
||||
Output:
|
||||
The capital of France is Paris.\nThe capital of the United States is Washington, D.C.\nThe capital of Canada is Ottawa.\nThe capital of Japan is Tokyo
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import time
|
||||
|
||||
import requests
|
||||
|
||||
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="http://127.0.0.1")
    parser.add_argument("--port", type=int, default=30000)
    args = parser.parse_args()

    base_url = f"{args.host}:{args.port}"

    # Greedy, non-streaming completion request against /generate.
    response = requests.post(
        base_url + "/generate",
        json={
            "text": "The capital of France is",
            "sampling_params": {
                "temperature": 0,
                "max_new_tokens": 32,
            },
        },
    )
    print(response.json())
|
||||
42
test/srt/test_httpserver_decode_stream.py
Normal file
42
test/srt/test_httpserver_decode_stream.py
Normal file
@@ -0,0 +1,42 @@
|
||||
"""
|
||||
python3 -m sglang.launch_server --model-path TinyLlama/TinyLlama-1.1B-Chat-v0.4 --port 30000
|
||||
|
||||
Output:
|
||||
The capital of France is Paris.\nThe capital of the United States is Washington, D.C.\nThe capital of Canada is Ottawa.\nThe capital of Japan is Tokyo
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import time
|
||||
|
||||
import requests
|
||||
|
||||
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="http://127.0.0.1")
    parser.add_argument("--port", type=int, default=30000)
    args = parser.parse_args()

    base_url = f"{args.host}:{args.port}"

    # Greedy streaming request; the server answers with NUL-delimited JSON
    # chunks, each carrying the full text generated so far.
    response = requests.post(
        base_url + "/generate",
        json={
            "text": "The capital of France is",
            "sampling_params": {
                "temperature": 0,
                "max_new_tokens": 1024,
            },
            "stream": True,
        },
        stream=True,
    )

    # Print only the text beyond what was already shown.
    printed = 0
    for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
        if not chunk:
            continue
        text = json.loads(chunk.decode())["text"].strip()
        print(text[printed:], end="", flush=True)
        printed = len(text)
    print("")
|
||||
84
test/srt/test_httpserver_llava.py
Normal file
84
test/srt/test_httpserver_llava.py
Normal file
@@ -0,0 +1,84 @@
|
||||
"""
|
||||
python3 -m sglang.launch_server --model-path liuhaotian/llava-v1.5-7b --tokenizer-path llava-hf/llava-1.5-7b-hf --port 30000
|
||||
|
||||
Output:
|
||||
The image features a man standing on the back of a yellow taxi cab, holding
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import json
|
||||
import time
|
||||
|
||||
import aiohttp
|
||||
import requests
|
||||
|
||||
|
||||
async def send_request(url, data, delay=0):
    """POST ``data`` as JSON to ``url`` after ``delay`` seconds; return the parsed reply."""
    await asyncio.sleep(delay)
    async with aiohttp.ClientSession() as session:
        async with session.post(url, json=data) as resp:
            return await resp.json()
|
||||
|
||||
|
||||
async def test_concurrent(args):
    """Send 8 identical image-caption requests concurrently and print the outputs."""
    url = f"{args.host}:{args.port}"

    tasks = []
    for _ in range(8):
        tasks.append(
            send_request(
                url + "/generate",
                {
                    "text": "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. USER: <image>\nDescribe this picture ASSISTANT:",
                    "image_data": "/home/ubuntu/sglang/test/lang/image.png",
                    "sampling_params": {
                        "temperature": 0,
                        "max_new_tokens": 16,
                    },
                },
            )
        )

    for ret in await asyncio.gather(*tasks):
        print(ret["text"])
|
||||
|
||||
|
||||
def test_streaming(args):
    """Stream a caption from /generate and echo incremental text to stdout."""
    url = f"{args.host}:{args.port}"

    response = requests.post(
        url + "/generate",
        json={
            "text": "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. USER: <image>\nDescribe this picture ASSISTANT:",
            "image_data": "/home/ubuntu/sglang/test/lang/image.png",
            "sampling_params": {
                "temperature": 0,
                "max_new_tokens": 128,
            },
            "stream": True,
        },
        stream=True,
    )

    # Chunks are NUL-delimited JSON objects, each carrying the full text so
    # far; print only what is beyond the portion already shown.
    printed = 0
    for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
        if not chunk:
            continue
        text = json.loads(chunk.decode())["text"].strip()
        print(text[printed:], end="", flush=True)
        printed = len(text)
    print("")
|
||||
|
||||
|
||||
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="http://127.0.0.1")
    parser.add_argument("--port", type=int, default=30000)
    cli_args = parser.parse_args()

    asyncio.run(test_concurrent(cli_args))

    test_streaming(cli_args)
|
||||
43
test/srt/test_httpserver_reuse.py
Normal file
43
test/srt/test_httpserver_reuse.py
Normal file
@@ -0,0 +1,43 @@
|
||||
"""
|
||||
python3 -m sglang.launch_server --model-path TinyLlama/TinyLlama-1.1B-Chat-v0.4 --port 30000
|
||||
|
||||
Output:
|
||||
The capital of France is Paris.\nThe capital of the United States is Washington, D.C.\nThe capital of Canada is Ottawa.\nThe capital of Japan is Tokyo
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import time
|
||||
|
||||
import requests
|
||||
|
||||
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="http://127.0.0.1")
    parser.add_argument("--port", type=int, default=30000)
    args = parser.parse_args()

    base_url = f"{args.host}:{args.port}"

    # The second prompt extends the first answer, so its prefix should be
    # reusable by the server's prefix cache.
    prompts = [
        "The capital of France is",
        "The capital of France is Paris.\nThe capital of the United States is",
    ]
    for prompt in prompts:
        response = requests.post(
            base_url + "/generate",
            json={
                "text": prompt,
                "sampling_params": {
                    "temperature": 0,
                    "max_new_tokens": 32,
                },
            },
        )
        print(response.json())
|
||||
Reference in New Issue
Block a user