release initial code

Co-authored-by: Ying Sheng <sqy1415@gmail.com>
Co-authored-by: Liangsheng Yin <hnyls2002@gmail.com>
Co-authored-by: Zhiqiang Xie <xiezhq@stanford.edu>
Co-authored-by: parasol-aser <3848358+parasol-aser@users.noreply.github.com>
Co-authored-by: LiviaSun <33578456+ChuyueSun@users.noreply.github.com>
Co-authored-by: Cody Yu <hao.yu.cody@gmail.com>
This commit is contained in:
Lianmin Zheng
2024-01-08 04:37:50 +00:00
parent f6d40df0ee
commit 22085081bb
145 changed files with 17802 additions and 2 deletions

1
test/killall_python.sh Normal file
View File

@@ -0,0 +1 @@
# Force-kill every running process whose command line contains 'python'
# (the `grep -v 'grep'` drops the grep process itself from the pid list).
kill -9 $(ps aux | grep 'python' | grep -v 'grep' | awk '{print $2}')

60
test/lang/run_all.py Normal file
View File

@@ -0,0 +1,60 @@
import argparse
import glob
import multiprocessing
import os
import time
import unittest
from sglang.utils import run_with_timeout
def run_unittest_files(files, args):
    """Run each unittest file in its own subprocess with a per-file timeout.

    Args:
        files: List of test file paths to run.
        args: Parsed CLI args; only `time_limit_per_file` (seconds) is used.

    Returns:
        True if every file ran to completion with exit code 0, False on the
        first failure or timeout.
    """
    for filename in files:

        def func():
            # Runs in the child process. unittest.main() calls sys.exit(),
            # so the child's exit code reflects the test result.
            print(filename)
            unittest.main(module=None, argv=["", "-vb"] + [filename])

        p = multiprocessing.Process(target=func)

        def run_one_file():
            p.start()
            p.join()

        try:
            run_with_timeout(run_one_file, timeout=args.time_limit_per_file)
            if p.exitcode != 0:
                return False
        except TimeoutError:
            p.terminate()
            time.sleep(5)
            # Fix: report WHICH file timed out (the original f-string had no
            # placeholder and printed a literal "(unknown)").
            print(
                f"\nTimeout after {args.time_limit_per_file} seconds "
                f"when running {filename}"
            )
            return False

    return True
if __name__ == "__main__":
    # Collect every test_*.py under the current directory and run them all,
    # reporting the total wall-clock time and exiting non-zero on failure.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--time-limit-per-file",
        type=int,
        default=1000,
        help="The time limit for running one file in seconds.",
    )
    args = parser.parse_args()

    test_files = glob.glob("**/test_*.py", recursive=True)

    start = time.time()
    ok = run_unittest_files(test_files, args)
    elapsed = time.time() - start

    status = "Success" if ok else "Fail"
    print(f"{status}. Time elapsed: {elapsed:.2f}s")
    exit(0 if ok else -1)

View File

@@ -0,0 +1,35 @@
import json
import unittest
from sglang.test.test_programs import test_mt_bench, test_stream
from sglang import Anthropic, set_default_backend
class TestAnthropicBackend(unittest.TestCase):
    """Smoke tests for the Anthropic backend (hits the real API)."""

    # Lazily-created backend handles, shared by every test in the class.
    backend = None
    chat_backend = None

    def setUp(self):
        # Create the backend once and install it as the global default.
        if TestAnthropicBackend.backend is None:
            TestAnthropicBackend.backend = Anthropic("claude-2")
            set_default_backend(TestAnthropicBackend.backend)

    def test_mt_bench(self):
        test_mt_bench()

    def test_stream(self):
        test_stream()
if __name__ == "__main__":
    # warnings="ignore" silences unittest's warning output during the run.
    unittest.main(warnings="ignore")
# from sglang.global_config import global_config
# global_config.verbosity = 2
# t = TestAnthropicBackend()
# t.setUp()
# t.test_mt_bench()

View File

@@ -0,0 +1,54 @@
import unittest
from sglang.backend.runtime_endpoint import RuntimeEndpoint
import sglang as sgl
class TestBind(unittest.TestCase):
    """Tests for bind/pin against a locally running runtime endpoint."""

    # Shared endpoint, created on first setUp call.
    backend = None

    def setUp(self):
        if TestBind.backend is None:
            TestBind.backend = RuntimeEndpoint(base_url="http://localhost:30000")

    def test_bind(self):
        @sgl.function
        def few_shot_qa(s, prompt, question):
            s += prompt
            s += "Q: What is the capital of France?\n"
            s += "A: Paris\n"
            s += "Q: " + question + "\n"
            s += "A:" + sgl.gen("answer", stop="\n")

        # Partially apply the prompt argument, then trace the bound program.
        bound = few_shot_qa.bind(
            prompt="The following are questions with answers.\n\n"
        )
        tracer = bound.trace()
        print(tracer.last_node.print_graph_dfs() + "\n")

    def test_pin(self):
        @sgl.function
        def few_shot_qa(s, prompt, question):
            s += prompt
            s += "Q: What is the capital of France?\n"
            s += "A: Paris\n"
            s += "Q: " + question + "\n"
            s += "A:" + sgl.gen("answer", stop="\n")

        # Pin and immediately unpin the bound program on the shared backend.
        bound = few_shot_qa.bind(
            prompt="Answer the following questions as if you were a 5-year-old kid.\n\n"
        )
        bound.pin(self.backend)
        bound.unpin(self.backend)
if __name__ == "__main__":
    # warnings="ignore" silences unittest's warning output during the run.
    unittest.main(warnings="ignore")
# t = TestBind()
# t.setUp()
# t.test_pin()

View File

@@ -0,0 +1,91 @@
import unittest
from sglang.test.test_programs import (
test_decode_int,
test_decode_json,
test_expert_answer,
test_few_shot_qa,
test_image_qa,
test_mt_bench,
test_parallel_decoding,
test_parallel_encoding,
test_react,
test_select,
test_stream,
test_tool_use,
)
from sglang import OpenAI, set_default_backend
class TestOpenAIBackend(unittest.TestCase):
    """Runs the shared program suite against real OpenAI endpoints."""

    # Completion, chat, and vision backends — created once, shared classwide.
    backend = None
    chat_backend = None
    chat_vision_backend = None

    def setUp(self):
        if TestOpenAIBackend.backend is None:
            TestOpenAIBackend.backend = OpenAI("gpt-3.5-turbo-instruct")
            TestOpenAIBackend.chat_backend = OpenAI("gpt-3.5-turbo")
            TestOpenAIBackend.chat_vision_backend = OpenAI("gpt-4-vision-preview")

    def test_few_shot_qa(self):
        set_default_backend(self.backend)
        test_few_shot_qa()

    def test_mt_bench(self):
        set_default_backend(self.chat_backend)
        test_mt_bench()

    def test_select(self):
        set_default_backend(self.backend)
        test_select(check_answer=True)

    def test_decode_int(self):
        set_default_backend(self.backend)
        test_decode_int()

    def test_decode_json(self):
        set_default_backend(self.backend)
        test_decode_json()

    def test_expert_answer(self):
        set_default_backend(self.backend)
        test_expert_answer()

    def test_tool_use(self):
        set_default_backend(self.backend)
        test_tool_use()

    def test_react(self):
        set_default_backend(self.backend)
        test_react()

    def test_parallel_decoding(self):
        set_default_backend(self.backend)
        test_parallel_decoding()

    def test_parallel_encoding(self):
        set_default_backend(self.backend)
        test_parallel_encoding()

    def test_image_qa(self):
        # Vision test uses the image-capable chat model.
        set_default_backend(self.chat_vision_backend)
        test_image_qa()

    def test_stream(self):
        set_default_backend(self.backend)
        test_stream()
if __name__ == "__main__":
    # warnings="ignore" silences unittest's warning output during the run.
    unittest.main(warnings="ignore")
# from sglang.global_config import global_config
# global_config.verbosity = 2
# t = TestOpenAIBackend()
# t.setUp()
# t.test_decode_json()

View File

@@ -0,0 +1,74 @@
"""
python3 -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000
"""
import json
import unittest
from sglang.test.test_programs import (
test_decode_int,
test_decode_json,
test_expert_answer,
test_few_shot_qa,
test_mt_bench,
test_parallel_decoding,
test_parallel_encoding,
test_react,
test_regex,
test_select,
test_stream,
test_tool_use,
)
import sglang as sgl
class TestSRTBackend(unittest.TestCase):
    """Runs the shared program suite against a local SRT runtime endpoint."""

    # Endpoint shared by all tests; created once and set as the default.
    backend = None

    def setUp(self):
        if TestSRTBackend.backend is None:
            TestSRTBackend.backend = sgl.RuntimeEndpoint(
                base_url="http://localhost:30000"
            )
            sgl.set_default_backend(TestSRTBackend.backend)

    def test_few_shot_qa(self):
        test_few_shot_qa()

    def test_mt_bench(self):
        test_mt_bench()

    def test_select(self):
        # Local models may answer differently, so don't check the answer.
        test_select(check_answer=False)

    def test_decode_int(self):
        test_decode_int()

    def test_expert_answer(self):
        test_expert_answer()

    def test_tool_use(self):
        test_tool_use()

    def test_parallel_decoding(self):
        test_parallel_decoding()

    def test_stream(self):
        test_stream()

    def test_regex(self):
        test_regex()

    # def test_parallel_encoding(self):
    #     test_parallel_encoding(check_answer=False)
if __name__ == "__main__":
    # warnings="ignore" silences unittest's warning output during the run.
    unittest.main(warnings="ignore")
# from sglang.global_config import global_config
# global_config.verbosity = 2
# t = TestSRTBackend()
# t.setUp()
# t.test_regex()

132
test/lang/test_tracing.py Normal file
View File

@@ -0,0 +1,132 @@
import unittest
from sglang.backend.base_backend import BaseBackend
from sglang.lang.chat_template import get_chat_template
import sglang as sgl
class TestTracing(unittest.TestCase):
    """Exercise tracing/compilation of @sgl.function programs."""

    def test_few_shot_qa(self):
        # A linear program: trace() should build a graph without any backend.
        @sgl.function
        def few_shot_qa(s, question):
            s += "The following are questions with answers.\n\n"
            s += "Q: What is the capital of France?\n"
            s += "A: Paris\n"
            s += "Q: " + question + "\n"
            s += "A:" + sgl.gen("answer", stop="\n")

        tracer = few_shot_qa.trace()
        print(tracer.last_node.print_graph_dfs() + "\n")

    def test_select(self):
        # sgl.select introduces a choice node in the traced graph.
        @sgl.function
        def capital(s):
            s += "The capital of France is"
            s += sgl.select("capital", ["Paris. ", "London. "])
            s += "It is a city" + sgl.gen("description", stop=".")

        tracer = capital.trace()
        print(tracer.last_node.print_graph_dfs() + "\n")

    def test_raise_warning(self):
        # Interpolating an argument with an f-string (instead of the s += ...
        # primitives) is untraceable: trace() must raise TypeError.
        @sgl.function
        def wrong(s, question):
            s += f"I want to ask {question}"

        try:
            tracer = wrong.trace()
            raised = False
        except TypeError:
            raised = True
        assert raised

    def test_multi_function(self):
        # One traced function calling another; compile() builds a static graph.
        @sgl.function
        def expand(s, tip):
            s += (
                "Please expand the following tip into a detailed paragraph:"
                + tip
                + "\n"
            )
            s += sgl.gen("detailed_tip")

        @sgl.function
        def tip_suggestion(s, topic):
            s += "Here are 2 tips for " + topic + ".\n"
            s += "1." + sgl.gen("tip_1", stop=["\n", ":", "."]) + "\n"
            s += "2." + sgl.gen("tip_2", stop=["\n", ":", "."]) + "\n"
            branch1 = expand(tip=s["tip_1"])
            branch2 = expand(tip=s["tip_2"])
            s += "Tip 1: " + branch1["detailed_tip"] + "\n"
            s += "Tip 2: " + branch2["detailed_tip"] + "\n"
            s += "In summary" + sgl.gen("summary")

        compiled = tip_suggestion.compile()
        compiled.print_graph()
        # NOTE(review): everything below runs against the real OpenAI API.
        sgl.set_default_backend(sgl.OpenAI("gpt-3.5-turbo-instruct"))
        state = compiled.run(topic="staying healthy")
        print(state.text() + "\n")
        # run_batch executes the compiled graph once per argument dict.
        states = compiled.run_batch(
            [
                {"topic": "staying healthy"},
                {"topic": "staying happy"},
                {"topic": "earning money"},
            ],
            temperature=0,
        )
        for s in states:
            print(s.text() + "\n")

    def test_role(self):
        # Role primitives compiled against a bare backend with a llama-2 chat
        # template; only prints the graph, no generation happens.
        @sgl.function
        def multi_turn_chat(s):
            s += sgl.user("Who are you?")
            s += sgl.assistant(sgl.gen("answer_1"))
            s += sgl.user("Who created you?")
            s += sgl.assistant(sgl.gen("answer_2"))

        backend = BaseBackend()
        backend.chat_template = get_chat_template("llama-2-chat")
        compiled = multi_turn_chat.compile(backend=backend)
        compiled.print_graph()

    def test_fork(self):
        # fork(3) creates three parallel branches that are later merged;
        # traced first, then actually run on OpenAI.
        @sgl.function
        def tip_suggestion(s):
            s += (
                "Here are three tips for staying healthy: "
                "1. Balanced Diet; "
                "2. Regular Exercise; "
                "3. Adequate Sleep\n"
            )
            forks = s.fork(3)
            for i in range(3):
                forks[i] += f"Now, expand tip {i+1} into a paragraph:\n"
                forks[i] += sgl.gen(f"detailed_tip")
            s += "Tip 1:" + forks[0]["detailed_tip"] + "\n"
            s += "Tip 2:" + forks[1]["detailed_tip"] + "\n"
            s += "Tip 3:" + forks[2]["detailed_tip"] + "\n"
            s += "In summary" + sgl.gen("summary")

        tracer = tip_suggestion.trace()
        print(tracer.last_node.print_graph_dfs())
        a = tip_suggestion.run(backend=sgl.OpenAI("gpt-3.5-turbo-instruct"))
        print(a.text())
if __name__ == "__main__":
    # warnings="ignore" silences unittest's warning output during the run.
    unittest.main(warnings="ignore")
# t = TestTracing()
# t.test_fork()

View File

@@ -0,0 +1,274 @@
import multiprocessing as mp
import time
from dataclasses import dataclass
import torch
import torch.distributed as dist
from sglang.srt.managers.router.model_runner import ModelRunner
from sglang.srt.model_config import ModelConfig
@dataclass
class BenchBatch:
    """Mutable container for one benchmark batch's runtime tensors.

    NOTE(review): the hand-written __init__ below overrides the
    dataclass-generated one, so the field declarations and their defaults
    serve only as documentation of the attributes used.
    """

    # Pools borrowed from the ModelRunner. NOTE(review): annotated as
    # torch.Tensor but these look like pool objects exposing alloc() /
    # req_to_token — kept as-is.
    req_to_token_pool: torch.Tensor
    token_to_kv_pool: torch.Tensor
    input_ids: torch.Tensor = None
    position_ids_offsets: torch.Tensor = None
    seq_lens: torch.Tensor = None
    prefix_lens: torch.Tensor = None
    req_pool_indices: torch.Tensor = None
    out_cache_loc: torch.Tensor = None
    out_cache_cont_start: torch.Tensor = None
    out_cache_cont_end: torch.Tensor = None

    def __init__(self, model_runner: ModelRunner):
        # Share the runner's request-slot and KV-cache pools.
        self.req_to_token_pool = model_runner.req_to_token_pool
        self.token_to_kv_pool = model_runner.token_to_kv_pool

    def init_prefill_batch(self, input_ids, batch_size, seq_len):
        """Set up a fresh prefill batch: `batch_size` sequences of `seq_len`."""
        self.input_ids = input_ids
        self.position_ids_offsets = torch.zeros(
            batch_size, dtype=torch.int32, device="cuda"
        )
        self.seq_lens = torch.full(
            (batch_size,), seq_len, dtype=torch.int32, device="cuda"
        )
        self.prefix_lens = torch.zeros(batch_size, dtype=torch.int32, device="cuda")
        self.req_pool_indices = self.req_to_token_pool.alloc(batch_size)
        self.out_cache_loc = self.token_to_kv_pool.alloc(batch_size * seq_len)
        for i in range(batch_size):
            n_idx = self.req_pool_indices[i].item()
            # Map each request's token positions to its slice of the KV cache.
            self.req_to_token_pool.req_to_token[n_idx, :seq_len] = self.out_cache_loc[
                i * seq_len : (i + 1) * seq_len
            ]

    def update_extend(
        self, input_ids, batch_size, prefix_len, extend_len, prefix_req_idx
    ):
        """Extend: fork each prefix request into several new requests.

        Each of the `prefix_req_idx` requests is forked
        `batch_size // len(prefix_req_idx)` times; the prefix KV mapping is
        copied and `extend_len` newly allocated cache slots are appended.
        """
        self.input_ids = input_ids
        self.position_ids_offsets = torch.zeros(
            batch_size, dtype=torch.int32, device="cuda"
        )
        self.seq_lens = torch.full(
            (batch_size,), prefix_len + extend_len, dtype=torch.int32, device="cuda"
        )
        self.prefix_lens = torch.full(
            (batch_size,), prefix_len, dtype=torch.int32, device="cuda"
        )
        self.req_pool_indices = self.req_to_token_pool.alloc(batch_size)
        self.out_cache_loc = self.token_to_kv_pool.alloc(batch_size * extend_len)
        req_to_token = self.req_to_token_pool.req_to_token
        fork_num = batch_size // prefix_req_idx.shape[0]
        for i in range(batch_size):
            p_idx = prefix_req_idx[i // fork_num].item()
            n_idx = self.req_pool_indices[i].item()
            # Copy the shared prefix mapping, then append the new extension.
            req_to_token[n_idx, :prefix_len] = req_to_token[p_idx, :prefix_len]
            req_to_token[
                n_idx, prefix_len : prefix_len + extend_len
            ] = self.out_cache_loc[i * extend_len : (i + 1) * extend_len]

    def update_decode(self, predict_ids, batch_size):
        """Advance every sequence by one token for a decode step."""
        assert predict_ids.shape[0] == batch_size
        assert batch_size == self.req_pool_indices.shape[0]
        self.input_ids = predict_ids.reshape(-1)
        self.prefix_lens = None
        # Contiguous allocation so decode can address a [start, end) range.
        (
            self.out_cache_loc,
            self.out_cache_cont_start,
            self.out_cache_cont_end,
        ) = self.token_to_kv_pool.alloc_contiguous(batch_size)
        self.req_to_token_pool.req_to_token[
            self.req_pool_indices, self.seq_lens
        ] = self.out_cache_loc
        self.seq_lens.add_(1)
def prefill(model_runner: ModelRunner, batch: BenchBatch):
    """Run the prefill forward pass; return greedy next-token ids (numpy)."""
    logits, _ = model_runner.forward_extend(
        batch.input_ids,
        batch.req_pool_indices,
        batch.seq_lens,
        batch.prefix_lens,
        batch.position_ids_offsets,
        batch.out_cache_loc,
        False,
    )
    # Greedy decoding: argmax over the softmax-normalized vocabulary.
    probs = torch.softmax(logits, dim=-1)
    next_ids = torch.argmax(probs, dim=1, keepdim=True)
    return next_ids.detach().cpu().numpy()
def extend(model_runner: ModelRunner, batch: BenchBatch):
    """Run an extend forward pass (last flag True); return greedy ids (numpy)."""
    logits, _ = model_runner.forward_extend(
        batch.input_ids,
        batch.req_pool_indices,
        batch.seq_lens,
        batch.prefix_lens,
        batch.position_ids_offsets,
        batch.out_cache_loc,
        True,
    )
    # Greedy decoding: argmax over the softmax-normalized vocabulary.
    probs = torch.softmax(logits, dim=-1)
    next_ids = torch.argmax(probs, dim=1, keepdim=True)
    return next_ids.detach().cpu().numpy()
def decode(model_runner: ModelRunner, batch: BenchBatch):
    """Run one decode step; return greedy next-token ids (numpy)."""
    logits = model_runner.forward_decode(
        batch.input_ids,
        batch.req_pool_indices,
        batch.seq_lens,
        None,
        batch.position_ids_offsets,
        None,
        batch.out_cache_cont_start,
        batch.out_cache_cont_end,
    )
    # Greedy decoding: argmax over the softmax-normalized vocabulary.
    probs = torch.softmax(logits, dim=-1)
    next_ids = torch.argmax(probs, dim=1, keepdim=True)
    return next_ids.detach().cpu().numpy()
def bench_generate_worker(
    model_path,
    tp_rank,
    tp_size,
    shared_num,
    unique_num,
    shared_len,
    unique_len,
    decode_len,
    model_mode,
):
    """Per-rank benchmark: prefill shared prefixes, extend unique branches,
    then run `decode_len` decode steps, timing each phase after a warm-up.
    """
    # Each shared prefix must fork into a whole number of unique branches.
    assert unique_num % shared_num == 0
    model_config = ModelConfig(path=model_path)
    model_runner = ModelRunner(
        model_config=model_config,
        mem_fraction_static=0.8,
        tp_rank=tp_rank,
        tp_size=tp_size,
        nccl_port=28888,
        model_mode=model_mode,
    )
    batch = BenchBatch(model_runner)
    # warm up
    for _ in range(1):
        input_ids = torch.randint(
            low=5, high=100, size=(shared_num * shared_len,)
        ).cuda()
        batch.init_prefill_batch(input_ids, shared_num, shared_len)
        _ = prefill(model_runner, batch)
        input_ids = torch.randint(
            low=5, high=100, size=(unique_num * unique_len,)
        ).cuda()
        batch.update_extend(
            input_ids, unique_num, shared_len, unique_len, batch.req_pool_indices
        )
        predict_ids = extend(model_runner, batch)
        for i in range(decode_len):
            predict_ids = torch.from_numpy(predict_ids).cuda()
            batch.update_decode(predict_ids, unique_num)
            predict_ids = decode(model_runner, batch)
        # Reset the pools so the timed run starts from a clean cache.
        model_runner.req_to_token_pool.clear()
        model_runner.token_to_kv_pool.clear()
    if tp_size > 1:
        # Sync all tensor-parallel ranks before timing.
        dist.barrier()
    prefill_start = time.time()
    input_ids = torch.randint(low=5, high=100, size=(shared_num * shared_len,)).cuda()
    batch.init_prefill_batch(input_ids, shared_num, shared_len)
    _ = prefill(model_runner, batch)
    if tp_rank == 0:
        print(f"prefill: {(time.time() - prefill_start) * 1000:.2f} ms")
    extend_start = time.time()
    input_ids = torch.randint(low=5, high=100, size=(unique_num * unique_len,)).cuda()
    batch.update_extend(
        input_ids, unique_num, shared_len, unique_len, batch.req_pool_indices
    )
    predict_ids = extend(model_runner, batch)
    if tp_rank == 0:
        print(f"extend: {(time.time() - extend_start) * 1000:.2f} ms")
    for i in range(decode_len):
        decode_start = time.time()
        predict_ids = torch.from_numpy(predict_ids).cuda()
        batch.update_decode(predict_ids, unique_num)
        predict_ids = decode(model_runner, batch)
        if tp_rank == 0:
            print(f"decode {i}: {(time.time() - decode_start) * 1000:.2f} ms")
def bench_generate(
    model_path,
    tp_size,
    shared_num,
    unique_num,
    shared_len,
    unique_len,
    decode_len,
    model_mode,
):
    """Spawn one benchmark worker per tensor-parallel rank and wait for all."""
    print(
        f"tp_size: {tp_size}, "
        f"shared_num: {shared_num}, "
        f"unique_num: {unique_num}, "
        f"shared_len: {shared_len}, "
        f"unique_len: {unique_len}, "
        f"decode_len: {decode_len}, "
        f"model_mode: {model_mode}"
    )
    procs = []
    for rank in range(tp_size):
        worker = mp.Process(
            target=bench_generate_worker,
            args=(
                model_path,
                rank,
                tp_size,
                shared_num,
                unique_num,
                shared_len,
                unique_len,
                decode_len,
                model_mode,
            ),
        )
        worker.start()
        procs.append(worker)
    for worker in procs:
        worker.join()
if __name__ == "__main__":
    # Default benchmark: one 256-token shared prefix forked into 32 unique
    # 256-token continuations, followed by 8 decode steps, on a single GPU.
    bench_generate(
        model_path="meta-llama/Llama-2-7b-chat-hf",
        tp_size=1,
        shared_num=1,
        unique_num=32,
        shared_len=256,
        unique_len=256,
        decode_len=8,
        model_mode=[],
    )

View File

@@ -0,0 +1,80 @@
import argparse
import os
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
@torch.inference_mode()
def normal_text(args):
    """Greedy-generate from a few prompts; print outputs and prefill logits."""
    tokenizer = AutoTokenizer.from_pretrained(args.model_path)
    model = AutoModelForCausalLM.from_pretrained(
        args.model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True
    )
    model.cuda()
    print(model)

    prompts = [
        "The capital of France is",
        "The capital of the United Kindom is",
        "Today is a sunny day and I like",
    ]
    max_new_tokens = 32

    for prompt in prompts:
        # A prompt may be a raw string or an already-tokenized id list.
        if isinstance(prompt, str):
            input_ids = tokenizer.encode(prompt, return_tensors="pt").cuda()
        else:
            input_ids = torch.tensor([prompt], device="cuda")

        output_ids = model.generate(
            input_ids, do_sample=False, max_new_tokens=max_new_tokens
        )
        output_str = tokenizer.decode(output_ids[0])
        print(output_str)

        # Logits at the last prompt position from a plain forward pass.
        prefill_logits = model.forward(input_ids).logits[0][-1]
        print("prefill logits", prefill_logits)
@torch.inference_mode()
def synthetic_tokens(args):
    """Feed a synthetic token range and print logits for each greedy step."""
    t = AutoTokenizer.from_pretrained(args.model_path)
    m = AutoModelForCausalLM.from_pretrained(
        args.model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True
    )
    m.cuda()
    print(m)

    input_len = 256
    output_len = 8
    prompts = [list(range(5, 5 + input_len))]

    for prompt in prompts:
        input_ids = prompt
        for step in range(output_len + 1):
            # Re-run the full (growing) sequence each step and read the
            # logits at the final position.
            logits = m.forward(torch.tensor([input_ids], device="cuda")).logits[
                0
            ][-1]
            if step == 0:
                print("prefill logits", logits)
            else:
                print("decode", step - 1, logits)
            # Greedy: append the argmax token for the next iteration.
            input_ids.append(torch.argmax(logits).item())
if __name__ == "__main__":
    # CLI entry: pick the model, then run the plain-text generation check.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model-path",
        type=str,
        default="TinyLlama/TinyLlama-1.1B-Chat-v0.4",
        # default="meta-llama/Llama-2-7b-chat-hf",
    )
    args = parser.parse_args()
    normal_text(args)
    # synthetic_tokens(args)

View File

@@ -0,0 +1,108 @@
import multiprocessing
import os
import time
import numpy as np
import torch
import torch.distributed as dist
import transformers
from sglang.srt.managers.router.infer_batch import Batch, ForwardMode, Req
from sglang.srt.managers.router.model_runner import ModelRunner
from sglang.srt.model_config import ModelConfig
from sglang.srt.sampling_params import SamplingParams
def test_generate_worker(model_path, tp_rank, tp_size):
    """Smoke-test ModelRunner prefill/extend/decode on two short prompts."""
    model_config = ModelConfig(path=model_path)
    model = ModelRunner(model_config, 0.8, tp_rank, tp_size, 28888)
    tokenizer = transformers.AutoTokenizer.from_pretrained(model_path)
    # Input
    prompts = [
        "The capital of France is",
        "Today is a sunny day and I like",
    ]
    sampling_params = SamplingParams(temperature=0)
    # Only the first `cut_num` tokens are prefilled; the remainder is fed in
    # a second extend pass to exercise prefix (KV cache) reuse.
    cut_num = 4
    reqs = []
    for i in range(len(prompts)):
        req = Req(i)
        req.input_ids = tokenizer.encode(prompts[i])[:cut_num]
        req.sampling_params = sampling_params
        reqs.append(req)
    # Prefill
    batch = Batch(reqs, model.req_to_token_pool, model.token_to_kv_pool, None)
    batch.init_extend_batch(model.model_config.vocab_size(), None)
    logits, _ = model.forward(batch, ForwardMode.EXTEND)
    next_token_ids, next_token_probs = batch.sample(logits)
    print("extend logits (first)", logits)
    # Extend
    for i in range(len(prompts)):
        req = reqs[i]
        req.input_ids += tokenizer.encode(prompts[i])[cut_num:]
        # Reuse the KV entries already written for the first cut_num tokens.
        req.prefix_indices = model.req_to_token_pool.req_to_token[
            batch.req_pool_indices[i], :cut_num
        ]
    batch = Batch(reqs, model.req_to_token_pool, model.token_to_kv_pool, None)
    batch.init_extend_batch(model.model_config.vocab_size(), None)
    logits, _ = model.forward(batch, ForwardMode.EXTEND)
    next_token_ids, next_token_probs = batch.sample(logits)
    print("extend logits", logits)
    print(
        "next_token_ids", next_token_ids, [tokenizer.decode(x) for x in next_token_ids]
    )
    # Decode
    for i in range(6):
        batch.update_for_decode(next_token_ids.cpu().numpy())
        # NOTE(review): the EXTEND calls above unpack a tuple from
        # model.forward, but DECODE here does not — confirm the DECODE path
        # really returns a bare tensor.
        logits = model.forward(batch, ForwardMode.DECODE)
        next_token_ids, next_token_probs = batch.sample(logits)
        print(
            "next_token_ids",
            next_token_ids,
            [tokenizer.decode(x) for x in next_token_ids],
        )
def test_generate(model_path, tp_size):
    """Launch one worker process per tensor-parallel rank and wait for all."""
    procs = []
    for rank in range(tp_size):
        worker = multiprocessing.Process(
            target=test_generate_worker,
            args=(
                model_path,
                rank,
                tp_size,
            ),
        )
        worker.start()
        procs.append(worker)
    for worker in procs:
        worker.join()
if __name__ == "__main__":
    # Disable tokenizer thread parallelism; worker processes are forked below.
    os.environ["TOKENIZERS_PARALLELISM"] = "false"
    test_generate("TinyLlama/TinyLlama-1.1B-Chat-v0.4", 1)
# Reference output for TinyLlama-1.1B-Chat-v0.4
# extend logits (first) tensor([[-10.0312, -9.5000, 0.8896, ..., -4.9375, -3.2402, -3.3633],
# [ -9.1797, -10.2500, 2.7168, ..., -4.3359, -4.0664, -4.1289]],
# device='cuda:0', dtype=torch.float16)
# extend logits tensor([[-8.3125, -7.1172, 3.3359, ..., -4.9531, -4.1289, -3.4121],
# [-9.6406, -9.0547, 4.0195, ..., -5.3086, -4.7188, -4.4609]],
# device='cuda:0', dtype=torch.float16)
# next_token_ids tensor([3681, 304], device='cuda:0') ['Paris', 'to']
# next_token_ids tensor([29889, 748], device='cuda:0') ['.', 'go']
# next_token_ids tensor([ 13, 363], device='cuda:0') ['\n', 'for']
# next_token_ids tensor([1576, 263], device='cuda:0') ['The', 'a']
# next_token_ids tensor([7483, 6686], device='cuda:0') ['capital', 'walk']
# next_token_ids tensor([310, 297], device='cuda:0') ['of', 'in']
# next_token_ids tensor([278, 278], device='cuda:0') ['the', 'the']

View File

@@ -0,0 +1,209 @@
import multiprocessing
import time
import numpy as np
import torch
import torch.distributed as dist
from sglang.srt.managers.router.model_runner import ModelRunner
from sglang.srt.model_config import ModelConfig
def test_generate_worker(
    model_path, tp_rank, tp_size, batch_size, input_len, output_len
):
    """Per-rank latency test: prefill a synthetic batch, decode `output_len`
    steps, free the caches, then repeat the whole thing with timing."""
    model_config = ModelConfig(path=model_path)
    model = ModelRunner(model_config, 0.8, tp_rank, tp_size, 28888)
    # Prepare data
    # Every request uses the same synthetic token ids [5, 5 + input_len).
    input_ids = np.vstack([np.arange(5, input_len + 5) for _ in range(batch_size)])
    input_ids = input_ids.reshape(-1)
    input_ids = torch.tensor(input_ids).cuda()

    def init_batch_data(model, batch_size, input_len):
        # Allocate request slots and KV cache, mapping each request's token
        # positions to its slice of the cache.
        req_pool_indices = model.req_to_token_pool.alloc(batch_size)
        seq_lens = torch.full(
            (batch_size,), input_len, dtype=torch.int32, device="cuda"
        )
        prefix_lens = torch.zeros(batch_size, dtype=torch.int32, device="cuda")
        position_ids_offsets = torch.zeros(batch_size, dtype=torch.int32, device="cuda")
        out_cache_loc = model.token_to_kv_pool.alloc(batch_size * input_len)
        for i in range(batch_size):
            req_idx = req_pool_indices[i].item()
            model.req_to_token_pool.req_to_token[req_idx, :input_len] = out_cache_loc[
                i * input_len : (i + 1) * input_len
            ]
        return (
            req_pool_indices,
            seq_lens,
            prefix_lens,
            position_ids_offsets,
            out_cache_loc,
        )

    def prefill(print_logits):
        # Writes the greedy next-token ids into the enclosing `predict_ids`.
        nonlocal predict_ids
        logits, _ = model.forward_prefill(
            input_ids,
            req_pool_indices,
            seq_lens,
            prefix_lens,
            position_ids_offsets,
            out_cache_loc,
            False,
        )
        prob_out = torch.softmax(logits, dim=-1)
        predict_ids = torch.argmax(prob_out, dim=1, keepdim=True)
        predict_ids = predict_ids.detach().cpu().numpy()
        if print_logits and tp_rank == 0:
            print("prefill logits", logits, logits.shape)

    def decode(print_logits):
        nonlocal predict_ids
        # Allocate one contiguous KV slot per sequence, record where the new
        # token's cache entry lives, then grow each sequence length.
        (
            out_cache_loc,
            out_cache_cont_start,
            out_cache_cont_end,
        ) = model.token_to_kv_pool.alloc_contiguous(batch_size)
        model.req_to_token_pool.req_to_token[req_pool_indices, seq_lens] = out_cache_loc
        seq_lens.add_(1)
        logits = model.forward_decode(
            torch.from_numpy(predict_ids).cuda().reshape(-1),
            req_pool_indices,
            seq_lens,
            None,
            position_ids_offsets,
            None,
            out_cache_cont_start,
            out_cache_cont_end,
        )
        prob_out = torch.softmax(logits, dim=-1)
        predict_ids = torch.argmax(prob_out, dim=1, keepdim=True)
        predict_ids = predict_ids.detach().cpu().numpy()
        if print_logits and tp_rank == 0:
            # `i` is the enclosing loop variable of the caller.
            print("decode", i, logits)

    # Warm up
    (
        req_pool_indices,
        seq_lens,
        prefix_lens,
        position_ids_offsets,
        out_cache_loc,
    ) = init_batch_data(model, batch_size, input_len)
    predict_ids = None
    prefill(True)
    for i in range(output_len):
        decode(True)
    # Free the warm-up allocations so the timed run starts from a clean pool.
    for i in range(batch_size):
        req_idx = req_pool_indices[i].item()
        model.token_to_kv_pool.free(
            model.req_to_token_pool.req_to_token[req_idx, : seq_lens[i]]
        )
    model.req_to_token_pool.free(req_pool_indices)
    # Benchmark
    if tp_size > 1:
        dist.barrier()
    start_time = prefill_start_time = time.time()
    (
        req_pool_indices,
        seq_lens,
        prefix_lens,
        position_ids_offsets,
        out_cache_loc,
    ) = init_batch_data(model, batch_size, input_len)
    prefill(False)
    if tp_rank == 0:
        print(f"prefill cost: {(time.time() - prefill_start_time) * 1000:.2f} ms")
    for i in range(output_len):
        step_start = time.time()
        decode(False)
        step_end = time.time()
        if i % 100 == 0 or i == output_len - 1:
            if tp_rank == 0:
                print(f"step {i} cost: {(step_end - step_start) * 1000:.2f} ms")
    end_time = time.time()
    if tp_rank == 0:
        print(f"total cost: {(end_time - start_time) * 1000:.2f}")
def test_generate(model_path, tp_size, batch_size, input_len, output_len):
    """Launch one latency-test worker per tensor-parallel rank and wait."""
    procs = []
    for rank in range(tp_size):
        worker = multiprocessing.Process(
            target=test_generate_worker,
            args=(
                model_path,
                rank,
                tp_size,
                batch_size,
                input_len,
                output_len,
            ),
        )
        worker.start()
        procs.append(worker)
    for worker in procs:
        worker.join()
if __name__ == "__main__":
    # Single-GPU run: batch 1, 256 input tokens, 8 decode steps.
    test_generate("TinyLlama/TinyLlama-1.1B-Chat-v0.4", 1, 1, 256, 8)
    # test_generate("meta-llama/Llama-2-7b-chat-hf", 1, 16, 256, 8)
# Reference output for TinyLlama-1.1B-Chat-v0.4 (1, 32, 8)
# prefill logits tensor([[-1.3380e-03, 4.4702e-01, 2.9082e+00, ..., -1.8398e+00,
# 1.8281e+00, 2.1816e+00]], device='cuda:0')
# decode 0 tensor([[-0.3904, 0.8784, 3.6934, ..., -2.4473, 1.5811, 2.0098]],
# device='cuda:0')
# decode 1 tensor([[-0.3552, 0.0635, 2.5781, ..., -2.5820, 1.3047, 1.7607]],
# device='cuda:0')
# decode 2 tensor([[-1.5645, -1.1963, 3.8145, ..., -2.9766, 1.0244, 1.0645]],
# device='cuda:0')
# decode 3 tensor([[-1.3682, -0.6548, 4.2734, ..., -2.8711, 1.1172, 1.1494]],
# device='cuda:0')
# decode 4 tensor([[-1.0205, -0.0060, 4.4844, ..., -2.7090, 1.6143, 1.8135]],
# device='cuda:0')
# decode 5 tensor([[ 0.4260, 1.6006, 4.3633, ..., -2.2480, 2.5547, 2.8379]],
# device='cuda:0')
# decode 6 tensor([[ 0.7095, 2.1816, 5.0078, ..., -2.1309, 3.0293, 3.0840]],
# device='cuda:0')
# decode 7 tensor([[-0.2883, 1.1289, 4.7188, ..., -2.4023, 2.1055, 2.1836]],
# device='cuda:0')
# Reference output for TinyLlama-1.1B-Chat-v0.4 (1, 256, 8)
# prefill logits tensor([[-2.5840, -2.7227, 6.8047, ..., -2.3613, 0.1224, 0.5952]],
# device='cuda:0')
# decode 0 tensor([[-0.6235, -0.7690, 9.2891, ..., -1.4922, 2.8008, 2.9531]],
# device='cuda:0')
# decode 1 tensor([[-1.3662, -1.4648, 7.1250, ..., -1.7861, 1.7363, 1.8857]],
# device='cuda:0')
# decode 2 tensor([[-0.8540, -0.5947, 9.1328, ..., -2.1211, 2.9707, 2.8945]],
# device='cuda:0')
# decode 3 tensor([[ 0.0652, 1.0312, 8.1250, ..., -2.0586, 3.4727, 3.6172]],
# device='cuda:0')
# decode 4 tensor([[-0.0459, 1.0098, 9.1406, ..., -2.1797, 3.8320, 3.9355]],
# device='cuda:0')
# decode 5 tensor([[ 0.2964, 1.3564, 9.8828, ..., -2.1602, 4.1836, 4.2422]],
# device='cuda:0')
# decode 6 tensor([[ 0.6475, 1.8105, 10.1250, ..., -2.0098, 4.2578, 4.4062]],
# device='cuda:0')
# decode 7 tensor([[ 0.4985, 1.4746, 9.9062, ..., -1.9141, 3.9863, 4.3047]],
# device='cuda:0')

View File

@@ -0,0 +1,161 @@
import multiprocessing
import time
import numpy as np
import torch
import torch.distributed as dist
from sglang.srt.hf_transformers_utils import get_processor
from sglang.srt.managers.router.infer_batch import ForwardMode
from sglang.srt.managers.router.model_runner import InputMetadata, ModelRunner
from sglang.srt.model_config import ModelConfig
from sglang.srt.utils import load_image
def init_batch_data(model, batch_size, input_len):
    """Allocate request slots and KV cache for a batch of equal-length inputs.

    Returns the tuple (req_pool_indices, seq_lens, prefix_lens,
    position_ids_offsets, out_cache_loc) used by prefill/decode.
    """
    req_pool_indices = model.req_to_token_pool.alloc(batch_size)
    seq_lens = torch.full((batch_size,), input_len, dtype=torch.int32, device="cuda")
    prefix_lens = torch.zeros(batch_size, dtype=torch.int32, device="cuda")
    position_ids_offsets = torch.zeros(batch_size, dtype=torch.int32, device="cuda")
    out_cache_loc = model.token_to_kv_pool.alloc(batch_size * input_len)
    for i in range(batch_size):
        # Fix: index the req_to_token table by the ALLOCATED request slot,
        # not the loop index — alloc() may return arbitrary indices.
        # (Matches the sibling init_batch_data in the model-runner latency
        # test; the old code only happened to work for batch_size == 1.)
        req_idx = req_pool_indices[i].item()
        model.req_to_token_pool.req_to_token[req_idx, :input_len] = out_cache_loc[
            i * input_len : (i + 1) * input_len
        ]
    return (
        req_pool_indices,
        seq_lens,
        prefix_lens,
        position_ids_offsets,
        out_cache_loc,
    )
def prefill(model, tp_rank, params, print_logits):
    """Multi-modal extend pass; return greedy next-token ids as numpy."""
    logits, _ = model.forward_extend_multi_modal(
        *params,
        False,
    )
    # Greedy decoding: argmax over the softmax-normalized vocabulary.
    probs = torch.softmax(logits, dim=-1)
    next_ids = torch.argmax(probs, dim=1, keepdim=True).detach().cpu().numpy()
    if print_logits and tp_rank == 0:
        print("prefill logits", logits, logits.shape)
    return next_ids
def decode(step, model, tp_rank, batch_size, predict_ids, params, print_logits):
    """One greedy decode step; returns the next predicted token ids (numpy).

    NOTE(review): the `out_cache_loc` unpacked from `params` is immediately
    shadowed by the contiguous allocation below, and `seq_lens` is mutated
    in place — the caller relies on that across successive steps.
    """
    (
        req_pool_indices,
        seq_lens,
        prefix_lens,
        position_ids_offsets,
        out_cache_loc,
    ) = params
    (
        out_cache_loc,
        out_cache_cont_start,
        out_cache_cont_end,
    ) = model.token_to_kv_pool.alloc_contiguous(batch_size)
    # Record where each new token's KV entry lives, then grow the sequences.
    model.req_to_token_pool.req_to_token[req_pool_indices, seq_lens] = out_cache_loc
    seq_lens.add_(1)
    logits = model.forward_decode(
        torch.from_numpy(predict_ids).cuda().reshape(-1),
        req_pool_indices,
        seq_lens,
        None,
        position_ids_offsets,
        None,
        out_cache_cont_start,
        out_cache_cont_end,
    )
    # Greedy sampling: argmax over the vocabulary.
    prob_out = torch.softmax(logits, dim=-1)
    predict_ids = torch.argmax(prob_out, dim=1, keepdim=True)
    predict_ids = predict_ids.detach().cpu().numpy()
    if print_logits and tp_rank == 0:
        print("decode", step, logits)
    return predict_ids
def test_generate_worker(
    model_path,
    tp_rank,
    tp_size,
):
    """End-to-end llava test: prefill an image+text prompt, greedily decode
    16 tokens, and assert the exact detokenized output."""
    model_config = ModelConfig(path=model_path)
    model = ModelRunner(model_config, 0.8, tp_rank, tp_size, 28888)
    # print(model.model)
    # Prepare data
    prompt = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. USER: <image>\nDescribe this picture ASSISTANT:"
    # NOTE(review): hard-coded absolute path; only works on the original box.
    image_path = "/home/ubuntu/sglang/test/lang/image.png"
    image = load_image(image_path)
    processor = get_processor("llava-hf/llava-1.5-7b-hf")
    input_ids = processor.tokenizer.encode(prompt)
    pixel_values = processor.image_processor(image)["pixel_values"]
    # Pad the token ids with image placeholder slots; `offset` marks where
    # the image embedding is inserted.
    input_ids, offset = model.model.pad_input_ids(
        input_ids,
        [
            0,
        ],
    )
    params = init_batch_data(model, 1, len(input_ids))
    # inference
    output_ids = []
    prefill_params = (
        torch.tensor(np.array(input_ids)).cuda(),
        np.array(pixel_values),
        [offset],
        *params,
    )
    predict_ids = prefill(model, tp_rank=0, params=prefill_params, print_logits=False)
    output_ids.append(predict_ids[0][0])
    # Greedy decode loop: each step feeds back the previous prediction.
    for i in range(16):
        predict_ids = decode(
            i,
            model,
            tp_rank=0,
            batch_size=1,
            predict_ids=predict_ids,
            params=params,
            print_logits=False,
        )
        output_ids.append(predict_ids[0][0])
    # detokenization
    output = processor.tokenizer.batch_decode(
        [output_ids], skip_special_tokens=True, clean_up_tokenization_spaces=False
    )[0]
    assert (
        output
        == "The image features a man standing on the back of a yellow taxi cab, holding"
    )
def test_generate(model_path, tp_size):
    """Launch one llava test worker per tensor-parallel rank and wait."""
    procs = []
    for rank in range(tp_size):
        worker = multiprocessing.Process(
            target=test_generate_worker,
            args=(
                model_path,
                rank,
                tp_size,
            ),
        )
        worker.start()
        procs.append(worker)
    for worker in procs:
        worker.join()
if __name__ == "__main__":
    # Requires a local llava checkpoint and the image at the hard-coded path.
    test_generate("liuhaotian/llava-v1.5-7b", 1)

163
test/srt/test_flashinfer.py Normal file
View File

@@ -0,0 +1,163 @@
import flashinfer
import pytest
import torch
from sglang.srt.layers.extend_attention import extend_attention_fwd
from sglang.srt.layers.token_attention import token_attention_fwd
@pytest.mark.parametrize("batch_size", [12, 37, 67])
@pytest.mark.parametrize("kv_len", [54, 97])
@pytest.mark.parametrize("qo_len", [37, 17])
@pytest.mark.parametrize("num_kv_heads", [4])
@pytest.mark.parametrize("num_qo_heads", [4, 32])
@pytest.mark.parametrize("head_dim", [128])
@pytest.mark.parametrize("use_wrapper", [True, False])
def test_batch_prefill_with_paged_kv_cache(
    batch_size,
    kv_len,
    qo_len,
    num_kv_heads,
    num_qo_heads,
    head_dim,
    use_wrapper,
):
    """Cross-check sglang's Triton extend-attention kernel against flashinfer's
    batched prefill over a paged KV cache (page size 1).

    The last ``qo_len`` tokens of each sequence are treated as the newly
    prefilled ("extend") part; all ``kv_len`` tokens form the KV cache.
    Both implementations must agree within fp16 tolerance.
    """
    # Queries for the extend tokens: (batch_size * qo_len, num_qo_heads, head_dim).
    q = torch.randn(batch_size * qo_len, num_qo_heads, head_dim).to(0).half()
    q_indptr = torch.arange(0, batch_size + 1).to(0).int() * qo_len
    total_tokens = kv_len * batch_size
    # Paged KV cache with page size 1:
    # (total_tokens, 2 [K/V], num_kv_heads, page_size=1, head_dim).
    kv_data = torch.randn(total_tokens, 2, num_kv_heads, 1, head_dim).to(0).half()
    kv_indptr = torch.arange(0, batch_size + 1).to(0).int() * kv_len
    # Sequences are laid out contiguously, one token per page.
    kv_indices = torch.arange(0, total_tokens).to(0).int()
    kv_last_page_len = torch.full((batch_size,), 1, dtype=torch.int32).to(0)
    # init args for triton kernel
    # Slice out the K (index 0) / V (index 1) of the last qo_len tokens of every
    # sequence and flatten to (batch_size * qo_len, num_kv_heads, head_dim).
    k_extend = (
        kv_data.view(batch_size, kv_len, 2, -1)[:, -qo_len:, 0]
        .contiguous()
        .view(-1, num_kv_heads, head_dim)
    )
    v_extend = (
        kv_data.view(batch_size, kv_len, 2, -1)[:, -qo_len:, 1]
        .contiguous()
        .view(-1, num_kv_heads, head_dim)
    )
    # Output buffer for the Triton kernel.
    o_triton = torch.empty_like(q)
    # Flat K/V buffers over all cached tokens, as the Triton kernel expects.
    k_buffer = kv_data[:, 0].view(-1, num_kv_heads, head_dim).contiguous()
    v_buffer = kv_data[:, 1].view(-1, num_kv_heads, head_dim).contiguous()
    # Request -> token-index mapping (contiguous layout).
    req_to_token = torch.arange(0, total_tokens).to(0).int().view(batch_size, kv_len)
    b_req_idx = torch.arange(0, batch_size).to(0).int()
    b_seq_len = torch.full((batch_size,), kv_len, dtype=torch.int32).to(0)
    b_start_loc_extend = torch.arange(0, batch_size).to(0).int() * qo_len
    b_seq_len_extend = torch.full((batch_size,), qo_len, dtype=torch.int32).to(0)
    max_len_in_batch = kv_len
    max_len_extend = qo_len
    extend_attention_fwd(
        q,
        k_extend,
        v_extend,
        o_triton,
        k_buffer,
        v_buffer,
        req_to_token,
        b_req_idx,
        None,  # b_start_loc = None
        b_seq_len,
        None,  # b_seq_len_prefix = None
        b_start_loc_extend,
        b_seq_len_extend,
        max_len_in_batch,
        max_len_extend,
    )
    if use_wrapper:
        # Wrapper API: plan once with begin_forward, then run forward.
        wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper()
        wrapper.begin_forward(q_indptr, batch_size, num_qo_heads, num_kv_heads)
        o = wrapper.forward(
            q, q_indptr, kv_data, kv_indptr, kv_indices, kv_last_page_len
        )
    else:
        # One-shot functional API.
        o = flashinfer.batch_prefill_with_paged_kv_cache(
            q,
            q_indptr,
            kv_data,
            kv_indptr,
            kv_indices,
            kv_last_page_len,
        )
    print("Mean: ", torch.mean(torch.abs(o - o_triton)))
    print("Max: ", torch.max(torch.abs(o - o_triton)))
    # fp16 kernels: allow a small numerical divergence.
    assert torch.allclose(o, o_triton, rtol=1e-2, atol=1e-3)
@pytest.mark.parametrize("batch_size", [12, 17, 37])
@pytest.mark.parametrize("kv_len", [54, 127, 537])
@pytest.mark.parametrize("num_kv_heads", [32])
@pytest.mark.parametrize("num_qo_heads", [32])
@pytest.mark.parametrize("head_dim", [128])
def test_batch_decode_with_paged_kv_cache(
    batch_size,
    kv_len,
    num_kv_heads,
    num_qo_heads,
    head_dim,
):
    """Cross-check sglang's Triton token-attention (decode) kernel against
    flashinfer's batched decode over a paged KV cache (page size 1).

    Each request decodes exactly one query token against its ``kv_len``
    cached tokens; both implementations must agree within fp16 tolerance.
    """
    # note(lsyin): when pytest, the number of heads cannot change, because triton kernel has a cache
    # to test different shape of decode, change the parameters in the __main__, and run decode only once
    # One query token per request: (batch_size, num_qo_heads, head_dim).
    q = torch.randn(batch_size, num_qo_heads, head_dim).to(0).half()
    total_tokens = kv_len * batch_size
    # Paged KV cache with page size 1:
    # (total_tokens, 2 [K/V], num_kv_heads, page_size=1, head_dim).
    kv_data = torch.randn(total_tokens, 2, num_kv_heads, 1, head_dim).to(0).half()
    kv_indptr = torch.arange(0, batch_size + 1).to(0).int() * kv_len
    kv_indices = torch.arange(0, total_tokens).to(0).int()
    kv_last_page_len = torch.full((batch_size,), 1, dtype=torch.int32).to(0)
    # init args for triton kernel
    k_buffer = kv_data[:, 0].view(-1, num_kv_heads, head_dim).contiguous()
    v_buffer = kv_data[:, 1].view(-1, num_kv_heads, head_dim).contiguous()
    # Output buffer for the Triton kernel.
    o_triton = torch.empty_like(q)
    # Request -> token-index mapping; sequences are laid out contiguously.
    req_to_token = (
        torch.arange(0, kv_len * batch_size).to(0).int().view(batch_size, kv_len)
    )
    b_req_idx = torch.arange(0, batch_size).to(0).int()
    b_start_loc = torch.arange(0, batch_size).to(0).int() * kv_len
    b_seq_len = torch.full((batch_size,), kv_len, dtype=torch.int32).to(0)
    max_len_in_batch = kv_len
    other_kv_index = 0
    token_attention_fwd(
        q,
        k_buffer,
        v_buffer,
        o_triton,
        req_to_token,
        b_req_idx,
        b_start_loc,
        b_seq_len,
        max_len_in_batch,
        other_kv_index,
        total_tokens,
    )
    # flashinfer reference: plan with begin_forward, then run forward.
    wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper()
    wrapper.begin_forward(
        kv_indptr,
        kv_last_page_len,
        batch_size,
        num_qo_heads,
        num_kv_heads,
        head_dim,
        1,  # page size
        "NONE",  # presumably the positional-encoding mode — confirm with flashinfer docs
        "float16",
    )
    o = wrapper.forward(q, kv_data, kv_indptr, kv_indices, kv_last_page_len)
    print("Mean: ", torch.mean(torch.abs(o - o_triton)))
    print("Max: ", torch.max(torch.abs(o - o_triton)))
    # fp16 kernels: allow a small numerical divergence.
    assert torch.allclose(o, o_triton, rtol=1e-2, atol=2e-3)
if __name__ == "__main__":
    # Run a few fixed shapes directly (without pytest). The decode test is run
    # only once per process because the Triton kernel caches on the head count
    # (see the note inside test_batch_decode_with_paged_kv_cache).
    test_batch_prefill_with_paged_kv_cache(12, 54, 37, 8, 8, 128, False)
    test_batch_prefill_with_paged_kv_cache(37, 1111, 456, 32, 32, 128, True)
    test_batch_decode_with_paged_kv_cache(12, 54, 4, 32, 128)

View File

@@ -0,0 +1,56 @@
"""
python3 -m sglang.launch_server --model-path TinyLlama/TinyLlama-1.1B-Chat-v0.4 --port 30000
Output:
The capital of France is Paris.\nThe capital of the United States is Washington, D.C.
The capital of the United Kindom is London.\nThe capital of the United Kingdom is London.\nThe capital of
"""
import argparse
import asyncio
import json
import time
import aiohttp
import requests
async def send_request(url, data, delay=0):
    """POST ``data`` as JSON to ``url`` after sleeping ``delay`` seconds.

    Returns the server response parsed as JSON.
    """
    await asyncio.sleep(delay)
    async with aiohttp.ClientSession() as session:
        async with session.post(url, json=data) as resp:
            payload = await resp.json()
    return payload
async def main(args):
    """Fire two /generate requests concurrently (the first delayed by 1s)
    and print both responses once they complete."""
    base = f"{args.host}:{args.port}"
    sampling = {"temperature": 0, "max_new_tokens": 128}
    pending = [
        send_request(
            base + "/generate",
            {"text": "The capital of France is", "sampling_params": sampling},
            delay=1,
        ),
        send_request(
            base + "/generate",
            {
                "text": "The capital of the United Kindom is",
                "sampling_params": sampling,
            },
        ),
    ]
    results = await asyncio.gather(*pending)
    print(results)
if __name__ == "__main__":
    # Parse server location, then drive the async test to completion.
    cli = argparse.ArgumentParser()
    cli.add_argument("--host", type=str, default="http://127.0.0.1")
    cli.add_argument("--port", type=int, default=30000)
    asyncio.run(main(cli.parse_args()))

View File

@@ -0,0 +1,31 @@
"""
python3 -m sglang.launch_server --model-path TinyLlama/TinyLlama-1.1B-Chat-v0.4 --port 30000
Output:
The capital of France is Paris.\nThe capital of the United States is Washington, D.C.\nThe capital of Canada is Ottawa.\nThe capital of Japan is Tokyo
"""
import argparse
import time
import requests
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="http://127.0.0.1")
    parser.add_argument("--port", type=int, default=30000)
    args = parser.parse_args()
    # Single non-streaming completion request with greedy sampling.
    payload = {
        "text": "The capital of France is",
        "sampling_params": {"temperature": 0, "max_new_tokens": 32},
    }
    result = requests.post(f"{args.host}:{args.port}" + "/generate", json=payload)
    print(result.json())

View File

@@ -0,0 +1,42 @@
"""
python3 -m sglang.launch_server --model-path TinyLlama/TinyLlama-1.1B-Chat-v0.4 --port 30000
Output:
The capital of France is Paris.\nThe capital of the United States is Washington, D.C.\nThe capital of Canada is Ottawa.\nThe capital of Japan is Tokyo
"""
import argparse
import json
import time
import requests
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="http://127.0.0.1")
    parser.add_argument("--port", type=int, default=30000)
    args = parser.parse_args()
    # Request a streamed completion; chunks arrive as NUL-delimited JSON objects.
    resp = requests.post(
        f"{args.host}:{args.port}" + "/generate",
        json={
            "text": "The capital of France is",
            "sampling_params": {
                "temperature": 0,
                "max_new_tokens": 1024,
            },
            "stream": True,
        },
        stream=True,
    )
    printed = 0
    for chunk in resp.iter_lines(decode_unicode=False, delimiter=b"\0"):
        if not chunk:
            continue
        text = json.loads(chunk.decode())["text"].strip()
        # Each chunk carries the full text so far; print only the new suffix.
        print(text[printed:], end="", flush=True)
        printed = len(text)
    print("")
View File

@@ -0,0 +1,84 @@
"""
python3 -m sglang.launch_server --model-path liuhaotian/llava-v1.5-7b --tokenizer-path llava-hf/llava-1.5-7b-hf --port 30000
Output:
The image features a man standing on the back of a yellow taxi cab, holding
"""
import argparse
import asyncio
import json
import time
import aiohttp
import requests
async def send_request(url, data, delay=0):
    """POST ``data`` as JSON to ``url`` after an optional ``delay`` (seconds)
    and return the decoded JSON response."""
    await asyncio.sleep(delay)
    async with aiohttp.ClientSession() as http:
        async with http.post(url, json=data) as response:
            result = await response.json()
    return result
async def test_concurrent(args):
    """Send 8 identical multimodal /generate requests concurrently and print
    the generated text of each response.

    Exercises the server's handling of several simultaneous vision-language
    requests that reference the same image.
    """
    url = f"{args.host}:{args.port}"
    # One shared payload is safe: send_request only serializes it to JSON.
    payload = {
        "text": "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. USER: <image>\nDescribe this picture ASSISTANT:",
        "image_data": "/home/ubuntu/sglang/test/lang/image.png",
        "sampling_params": {
            "temperature": 0,
            "max_new_tokens": 16,
        },
    }
    # Build all coroutines up front (comprehension instead of the manual
    # append loop), then await them concurrently.
    tasks = [send_request(url + "/generate", payload) for _ in range(8)]
    rets = await asyncio.gather(*tasks)
    for ret in rets:
        print(ret["text"])
def test_streaming(args):
    """Send one streaming multimodal /generate request and print the output
    incrementally as NUL-delimited JSON chunks arrive."""
    endpoint = f"{args.host}:{args.port}" + "/generate"
    resp = requests.post(
        endpoint,
        json={
            "text": "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. USER: <image>\nDescribe this picture ASSISTANT:",
            "image_data": "/home/ubuntu/sglang/test/lang/image.png",
            "sampling_params": {
                "temperature": 0,
                "max_new_tokens": 128,
            },
            "stream": True,
        },
        stream=True,
    )
    shown = 0
    for chunk in resp.iter_lines(decode_unicode=False, delimiter=b"\0"):
        if not chunk:
            continue
        text = json.loads(chunk.decode())["text"].strip()
        # Each chunk carries the full text so far; print only the new suffix.
        print(text[shown:], end="", flush=True)
        shown = len(text)
    print("")
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="http://127.0.0.1")
    parser.add_argument("--port", type=int, default=30000)
    args = parser.parse_args()
    # Run the concurrent test first, then the streaming test, against the
    # same running server.
    asyncio.run(test_concurrent(args))
    test_streaming(args)

View File

@@ -0,0 +1,43 @@
"""
python3 -m sglang.launch_server --model-path TinyLlama/TinyLlama-1.1B-Chat-v0.4 --port 30000
Output:
The capital of France is Paris.\nThe capital of the United States is Washington, D.C.\nThe capital of Canada is Ottawa.\nThe capital of Japan is Tokyo
"""
import argparse
import time
import requests
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="http://127.0.0.1")
    parser.add_argument("--port", type=int, default=30000)
    args = parser.parse_args()
    url = f"{args.host}:{args.port}"
    # First request: a fresh prompt, greedy sampling.
    response = requests.post(
        url + "/generate",
        json={
            "text": "The capital of France is",
            "sampling_params": {
                "temperature": 0,
                "max_new_tokens": 32,
            },
        },
    )
    print(response.json())
    # Second request: extends the first prompt with its completion text —
    # presumably to exercise the server's prefix (radix) cache; confirm
    # against the server implementation.
    response = requests.post(
        url + "/generate",
        json={
            "text": "The capital of France is Paris.\nThe capital of the United States is",
            "sampling_params": {
                "temperature": 0,
                "max_new_tokens": 32,
            },
        },
    )
    print(response.json())