release initial code

Co-authored-by: Ying Sheng <sqy1415@gmail.com>
Co-authored-by: Liangsheng Yin <hnyls2002@gmail.com>
Co-authored-by: Zhiqiang Xie <xiezhq@stanford.edu>
Co-authored-by: parasol-aser <3848358+parasol-aser@users.noreply.github.com>
Co-authored-by: LiviaSun <33578456+ChuyueSun@users.noreply.github.com>
Co-authored-by: Cody Yu <hao.yu.cody@gmail.com>
This commit is contained in:
Lianmin Zheng
2024-01-08 04:37:50 +00:00
parent f6d40df0ee
commit 22085081bb
145 changed files with 17802 additions and 2 deletions

View File

@@ -0,0 +1,274 @@
import multiprocessing as mp
import time
from dataclasses import dataclass
import torch
import torch.distributed as dist
from sglang.srt.managers.router.model_runner import ModelRunner
from sglang.srt.model_config import ModelConfig
@dataclass
class BenchBatch:
    """Mutable per-batch state reused by the prefill / extend / decode steps.

    The class defines its own ``__init__``, which ``@dataclass`` leaves in
    place; only the two pool handles are set at construction time. All other
    fields start as ``None`` and are filled in by ``init_prefill_batch``,
    ``update_extend`` or ``update_decode``.
    """

    # Pool handles copied from the model runner in __init__.
    req_to_token_pool: torch.Tensor
    token_to_kv_pool: torch.Tensor

    # Per-step state, populated by the init_/update_ methods below.
    input_ids: torch.Tensor = None
    position_ids_offsets: torch.Tensor = None
    seq_lens: torch.Tensor = None
    prefix_lens: torch.Tensor = None
    req_pool_indices: torch.Tensor = None
    out_cache_loc: torch.Tensor = None
    out_cache_cont_start: torch.Tensor = None
    out_cache_cont_end: torch.Tensor = None

    def __init__(self, model_runner: ModelRunner):
        """Borrow the request and KV pools owned by ``model_runner``."""
        self.req_to_token_pool = model_runner.req_to_token_pool
        self.token_to_kv_pool = model_runner.token_to_kv_pool

    def init_prefill_batch(self, input_ids, batch_size, seq_len):
        """Set up a fresh batch of ``batch_size`` requests of ``seq_len`` tokens.

        Allocates one request slot per request and ``batch_size * seq_len``
        cache locations, then records each request's slice of locations in the
        ``req_to_token`` table.
        """
        int32_cuda = {"dtype": torch.int32, "device": "cuda"}

        self.input_ids = input_ids
        self.position_ids_offsets = torch.zeros(batch_size, **int32_cuda)
        self.seq_lens = torch.full((batch_size,), seq_len, **int32_cuda)
        self.prefix_lens = torch.zeros(batch_size, **int32_cuda)
        self.req_pool_indices = self.req_to_token_pool.alloc(batch_size)
        self.out_cache_loc = self.token_to_kv_pool.alloc(batch_size * seq_len)

        token_table = self.req_to_token_pool.req_to_token
        for row in range(batch_size):
            slot = self.req_pool_indices[row].item()
            begin = row * seq_len
            token_table[slot, :seq_len] = self.out_cache_loc[begin : begin + seq_len]

    def update_extend(
        self, input_ids, batch_size, prefix_len, extend_len, prefix_req_idx
    ):
        """Set up a batch that extends existing prefixes by ``extend_len`` tokens.

        Each new request copies the first ``prefix_len`` table entries of a
        request listed in ``prefix_req_idx`` and appends ``extend_len`` newly
        allocated cache locations. Consecutive groups of
        ``batch_size // len(prefix_req_idx)`` new requests share one prefix.
        """
        int32_cuda = {"dtype": torch.int32, "device": "cuda"}
        total_len = prefix_len + extend_len

        self.input_ids = input_ids
        self.position_ids_offsets = torch.zeros(batch_size, **int32_cuda)
        self.seq_lens = torch.full((batch_size,), total_len, **int32_cuda)
        self.prefix_lens = torch.full((batch_size,), prefix_len, **int32_cuda)
        self.req_pool_indices = self.req_to_token_pool.alloc(batch_size)
        self.out_cache_loc = self.token_to_kv_pool.alloc(batch_size * extend_len)

        token_table = self.req_to_token_pool.req_to_token
        forks_per_prefix = batch_size // prefix_req_idx.shape[0]
        for row in range(batch_size):
            src_slot = prefix_req_idx[row // forks_per_prefix].item()
            dst_slot = self.req_pool_indices[row].item()
            # Share the prefix entries, then append this request's own locations.
            token_table[dst_slot, :prefix_len] = token_table[src_slot, :prefix_len]
            begin = row * extend_len
            token_table[dst_slot, prefix_len:total_len] = self.out_cache_loc[
                begin : begin + extend_len
            ]

    def update_decode(self, predict_ids, batch_size):
        """Advance the batch by one token per request.

        ``predict_ids`` becomes the flattened input, one contiguous cache
        location per request is allocated and written at column ``seq_lens``
        of the table, and ``seq_lens`` is then incremented in place.
        """
        assert predict_ids.shape[0] == batch_size
        assert batch_size == self.req_pool_indices.shape[0]

        self.input_ids = predict_ids.reshape(-1)
        self.prefix_lens = None

        new_loc, cont_start, cont_end = self.token_to_kv_pool.alloc_contiguous(
            batch_size
        )
        self.out_cache_loc = new_loc
        self.out_cache_cont_start = cont_start
        self.out_cache_cont_end = cont_end

        token_table = self.req_to_token_pool.req_to_token
        token_table[self.req_pool_indices, self.seq_lens] = self.out_cache_loc
        self.seq_lens.add_(1)
def prefill(model_runner: ModelRunner, batch: BenchBatch):
    """Run ``forward_extend`` on ``batch`` and greedily pick the next token ids.

    Args:
        model_runner: Runner whose ``forward_extend`` is called.
        batch: Batch state whose tensors are forwarded positionally.

    Returns:
        NumPy array of argmax ids taken over dim 1 of the softmaxed logits,
        with that dimension kept.
    """
    forward_args = (
        batch.input_ids,
        batch.req_pool_indices,
        batch.seq_lens,
        batch.prefix_lens,
        batch.position_ids_offsets,
        batch.out_cache_loc,
        False,  # trailing positional flag; its meaning is defined by forward_extend
    )
    logits, _ = model_runner.forward_extend(*forward_args)
    probs = torch.softmax(logits, dim=-1)
    next_ids = torch.argmax(probs, dim=1, keepdim=True)
    return next_ids.detach().cpu().numpy()
def extend(model_runner: ModelRunner, batch: BenchBatch):
    """Run ``forward_extend`` on an extend batch and greedily pick next token ids.

    Same call as ``prefill`` except that the trailing positional flag is True.

    Args:
        model_runner: Runner whose ``forward_extend`` is called.
        batch: Batch state whose tensors are forwarded positionally.

    Returns:
        NumPy array of argmax ids taken over dim 1 of the softmaxed logits,
        with that dimension kept.
    """
    forward_args = (
        batch.input_ids,
        batch.req_pool_indices,
        batch.seq_lens,
        batch.prefix_lens,
        batch.position_ids_offsets,
        batch.out_cache_loc,
        True,  # trailing positional flag; its meaning is defined by forward_extend
    )
    logits, _ = model_runner.forward_extend(*forward_args)
    probs = torch.softmax(logits, dim=-1)
    next_ids = torch.argmax(probs, dim=1, keepdim=True)
    return next_ids.detach().cpu().numpy()
def decode(model_runner: ModelRunner, batch: BenchBatch):
    """Run one ``forward_decode`` step on ``batch`` and greedily pick next ids.

    The prefix-length and out-cache-location positions are passed as ``None``;
    the contiguous start/end values from the batch are passed instead.

    Returns:
        NumPy array of argmax ids taken over dim 1 of the softmaxed logits,
        with that dimension kept.
    """
    forward_args = (
        batch.input_ids,
        batch.req_pool_indices,
        batch.seq_lens,
        None,
        batch.position_ids_offsets,
        None,
        batch.out_cache_cont_start,
        batch.out_cache_cont_end,
    )
    logits = model_runner.forward_decode(*forward_args)
    probs = torch.softmax(logits, dim=-1)
    next_ids = torch.argmax(probs, dim=1, keepdim=True)
    return next_ids.detach().cpu().numpy()
def bench_generate_worker(
    model_path,
    tp_rank,
    tp_size,
    shared_num,
    unique_num,
    shared_len,
    unique_len,
    decode_len,
    model_mode,
):
    """Benchmark prefill, extend and decode on one tensor-parallel rank.

    Runs the full sequence once as warm-up, clears both pools, then repeats it
    with wall-clock timing. Only rank 0 prints the timings.

    Args:
        model_path: Passed to ``ModelConfig(path=...)``.
        tp_rank: Rank of this worker.
        tp_size: Number of workers; a ``dist.barrier()`` runs when it is > 1.
        shared_num: Number of shared-prefix requests in the prefill batch.
        unique_num: Number of requests in the extend/decode batch; must be a
            multiple of ``shared_num``.
        shared_len: Token count of each shared prefix.
        unique_len: Token count appended to each request in the extend step.
        decode_len: Number of decode steps.
        model_mode: Forwarded to ``ModelRunner``.
    """
    assert unique_num % shared_num == 0

    model_runner = ModelRunner(
        model_config=ModelConfig(path=model_path),
        mem_fraction_static=0.8,
        tp_rank=tp_rank,
        tp_size=tp_size,
        nccl_port=28888,
        model_mode=model_mode,
    )
    batch = BenchBatch(model_runner)

    def random_ids(count):
        # Random token ids in [5, 100), drawn on CPU and then moved to the GPU.
        return torch.randint(low=5, high=100, size=(count,)).cuda()

    # warm up
    batch.init_prefill_batch(
        random_ids(shared_num * shared_len), shared_num, shared_len
    )
    prefill(model_runner, batch)

    batch.update_extend(
        random_ids(unique_num * unique_len),
        unique_num,
        shared_len,
        unique_len,
        batch.req_pool_indices,
    )
    predict_ids = extend(model_runner, batch)

    for _ in range(decode_len):
        batch.update_decode(torch.from_numpy(predict_ids).cuda(), unique_num)
        predict_ids = decode(model_runner, batch)

    # Release everything the warm-up allocated before measuring.
    model_runner.req_to_token_pool.clear()
    model_runner.token_to_kv_pool.clear()

    if tp_size > 1:
        dist.barrier()

    # Timed prefill (the measurement includes generating the random inputs).
    prefill_start = time.time()
    batch.init_prefill_batch(
        random_ids(shared_num * shared_len), shared_num, shared_len
    )
    prefill(model_runner, batch)
    if tp_rank == 0:
        print(f"prefill: {(time.time() - prefill_start) * 1000:.2f} ms")

    # Timed extend.
    extend_start = time.time()
    batch.update_extend(
        random_ids(unique_num * unique_len),
        unique_num,
        shared_len,
        unique_len,
        batch.req_pool_indices,
    )
    predict_ids = extend(model_runner, batch)
    if tp_rank == 0:
        print(f"extend: {(time.time() - extend_start) * 1000:.2f} ms")

    # Timed decode steps.
    for i in range(decode_len):
        decode_start = time.time()
        batch.update_decode(torch.from_numpy(predict_ids).cuda(), unique_num)
        predict_ids = decode(model_runner, batch)
        if tp_rank == 0:
            print(f"decode {i}: {(time.time() - decode_start) * 1000:.2f} ms")
def bench_generate(
    model_path,
    tp_size,
    shared_num,
    unique_num,
    shared_len,
    unique_len,
    decode_len,
    model_mode,
):
    """Print the benchmark configuration and run one worker process per rank.

    Starts ``tp_size`` processes running ``bench_generate_worker`` and waits
    for all of them to finish.
    """
    print(
        f"tp_size: {tp_size}, "
        f"shared_num: {shared_num}, "
        f"unique_num: {unique_num}, "
        f"shared_len: {shared_len}, "
        f"unique_len: {unique_len}, "
        f"decode_len: {decode_len}, "
        f"model_mode: {model_mode}"
    )

    def worker_args(tp_rank):
        # Positional arguments in the order bench_generate_worker expects.
        return (
            model_path,
            tp_rank,
            tp_size,
            shared_num,
            unique_num,
            shared_len,
            unique_len,
            decode_len,
            model_mode,
        )

    workers = []
    for tp_rank in range(tp_size):
        proc = mp.Process(target=bench_generate_worker, args=worker_args(tp_rank))
        proc.start()
        workers.append(proc)

    for proc in workers:
        proc.join()
if __name__ == "__main__":
    # Default scenario: one shared 256-token prefix forked into 32 requests,
    # each extended by 256 tokens and then decoded for 8 steps on a single rank.
    bench_config = {
        "model_path": "meta-llama/Llama-2-7b-chat-hf",
        "tp_size": 1,
        "shared_num": 1,
        "unique_num": 32,
        "shared_len": 256,
        "unique_len": 256,
        "decode_len": 8,
        "model_mode": [],
    }
    bench_generate(**bench_config)

View File

@@ -0,0 +1,80 @@
import argparse
import os
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
@torch.inference_mode()
def normal_text(args):
    """Print HF reference generations and last-position logits for fixed prompts.

    Loads the tokenizer and fp16 model from ``args.model_path``, moves the
    model to the GPU, and for each prompt prints the greedy generation (up to
    32 new tokens) followed by the logits at the last input position.
    """
    tokenizer = AutoTokenizer.from_pretrained(args.model_path)
    model = AutoModelForCausalLM.from_pretrained(
        args.model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True
    )
    model.cuda()
    print(model)

    max_new_tokens = 32
    prompts = [
        "The capital of France is",
        "The capital of the United Kindom is",
        "Today is a sunny day and I like",
    ]

    for prompt in prompts:
        # A prompt may be text (tokenized here) or an already-tokenized id list.
        input_ids = (
            tokenizer.encode(prompt, return_tensors="pt").cuda()
            if isinstance(prompt, str)
            else torch.tensor([prompt], device="cuda")
        )

        generated = model.generate(
            input_ids, do_sample=False, max_new_tokens=max_new_tokens
        )
        print(tokenizer.decode(generated[0]))

        prefill_logits = model.forward(input_ids).logits[0][-1]
        print("prefill logits", prefill_logits)
@torch.inference_mode()
def synthetic_tokens(args):
    """Print HF reference logits for a synthetic 256-token prompt.

    The prompt is the id sequence 5..260. The whole sequence is re-run through
    the model at every step (one prefill plus 8 decode steps), and the argmax
    token is appended after each step.
    """
    # The tokenizer is loaded as in the original script although it is not used.
    tokenizer = AutoTokenizer.from_pretrained(args.model_path)
    model = AutoModelForCausalLM.from_pretrained(
        args.model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True
    )
    model.cuda()
    print(model)

    input_len = 256
    output_len = 8
    prompts = [list(range(5, 5 + input_len))]

    for token_ids in prompts:
        for step in range(output_len + 1):
            full_sequence = torch.tensor([token_ids], device="cuda")
            prefill_logits = model.forward(full_sequence).logits[0][-1]

            if step == 0:
                print("prefill logits", prefill_logits)
            else:
                print("decode", step - 1, prefill_logits)

            # Greedy choice; grows the prompt list in place for the next step.
            token_ids.append(torch.argmax(prefill_logits).item())
if __name__ == "__main__":
    # Alternative model: default="meta-llama/Llama-2-7b-chat-hf"
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--model-path",
        type=str,
        default="TinyLlama/TinyLlama-1.1B-Chat-v0.4",
    )
    args = cli.parse_args()

    # Only the text-prompt reference runs by default; call synthetic_tokens(args)
    # instead to produce the synthetic-token reference.
    normal_text(args)

View File

@@ -0,0 +1,108 @@
import multiprocessing
import os
import time
import numpy as np
import torch
import torch.distributed as dist
import transformers
from sglang.srt.managers.router.infer_batch import Batch, ForwardMode, Req
from sglang.srt.managers.router.model_runner import ModelRunner
from sglang.srt.model_config import ModelConfig
from sglang.srt.sampling_params import SamplingParams
def test_generate_worker(model_path, tp_rank, tp_size):
    """Run a two-stage extend followed by six decode steps and print the results.

    The first ``cut_num`` tokens of each prompt are run as an initial extend
    batch. The remaining tokens are then run as a second extend batch whose
    requests reuse the first batch's ``req_to_token`` entries as
    ``prefix_indices``. Finally six decode steps are sampled with temperature 0.
    """
    model = ModelRunner(ModelConfig(path=model_path), 0.8, tp_rank, tp_size, 28888)
    tokenizer = transformers.AutoTokenizer.from_pretrained(model_path)

    # Input
    prompts = [
        "The capital of France is",
        "Today is a sunny day and I like",
    ]
    sampling_params = SamplingParams(temperature=0)
    cut_num = 4

    reqs = []
    for rid, prompt in enumerate(prompts):
        req = Req(rid)
        req.input_ids = tokenizer.encode(prompt)[:cut_num]
        req.sampling_params = sampling_params
        reqs.append(req)

    def run_extend():
        # Build a batch from the current requests, run it in EXTEND mode, sample.
        new_batch = Batch(reqs, model.req_to_token_pool, model.token_to_kv_pool, None)
        new_batch.init_extend_batch(model.model_config.vocab_size(), None)
        extend_logits, _ = model.forward(new_batch, ForwardMode.EXTEND)
        sampled_ids, _ = new_batch.sample(extend_logits)
        return new_batch, extend_logits, sampled_ids

    # Prefill
    batch, logits, next_token_ids = run_extend()
    print("extend logits (first)", logits)

    # Extend
    for row, (req, prompt) in enumerate(zip(reqs, prompts)):
        req.input_ids += tokenizer.encode(prompt)[cut_num:]
        req.prefix_indices = model.req_to_token_pool.req_to_token[
            batch.req_pool_indices[row], :cut_num
        ]
    batch, logits, next_token_ids = run_extend()
    print("extend logits", logits)
    print(
        "next_token_ids", next_token_ids, [tokenizer.decode(x) for x in next_token_ids]
    )

    # Decode
    for _ in range(6):
        batch.update_for_decode(next_token_ids.cpu().numpy())
        logits = model.forward(batch, ForwardMode.DECODE)
        next_token_ids, _ = batch.sample(logits)
        print(
            "next_token_ids",
            next_token_ids,
            [tokenizer.decode(x) for x in next_token_ids],
        )
def test_generate(model_path, tp_size):
    """Run ``test_generate_worker`` in one process per tensor-parallel rank.

    Args:
        model_path: Model path forwarded to every worker.
        tp_size: Number of worker processes to start.

    Raises:
        RuntimeError: If any worker exits with a non-zero exit code. Without
            this check an exception inside a child process would only be
            printed by the child and the test would appear to succeed.
    """
    workers = []
    for tp_rank in range(tp_size):
        proc = multiprocessing.Process(
            target=test_generate_worker,
            args=(
                model_path,
                tp_rank,
                tp_size,
            ),
        )
        proc.start()
        workers.append(proc)

    for proc in workers:
        proc.join()

    # After join() every process has an exit code; 0 means it finished cleanly.
    failed = [
        (tp_rank, proc.exitcode)
        for tp_rank, proc in enumerate(workers)
        if proc.exitcode != 0
    ]
    if failed:
        raise RuntimeError(
            f"worker(s) exited abnormally (tp_rank, exitcode): {failed}"
        )
if __name__ == "__main__":
    # Set the tokenizers parallelism flag before any worker process is started.
    os.environ.update(TOKENIZERS_PARALLELISM="false")
    test_generate(model_path="TinyLlama/TinyLlama-1.1B-Chat-v0.4", tp_size=1)
# Reference output for TinyLlama-1.1B-Chat-v0.4
# extend logits (first) tensor([[-10.0312, -9.5000, 0.8896, ..., -4.9375, -3.2402, -3.3633],
# [ -9.1797, -10.2500, 2.7168, ..., -4.3359, -4.0664, -4.1289]],
# device='cuda:0', dtype=torch.float16)
# extend logits tensor([[-8.3125, -7.1172, 3.3359, ..., -4.9531, -4.1289, -3.4121],
# [-9.6406, -9.0547, 4.0195, ..., -5.3086, -4.7188, -4.4609]],
# device='cuda:0', dtype=torch.float16)
# next_token_ids tensor([3681, 304], device='cuda:0') ['Paris', 'to']
# next_token_ids tensor([29889, 748], device='cuda:0') ['.', 'go']
# next_token_ids tensor([ 13, 363], device='cuda:0') ['\n', 'for']
# next_token_ids tensor([1576, 263], device='cuda:0') ['The', 'a']
# next_token_ids tensor([7483, 6686], device='cuda:0') ['capital', 'walk']
# next_token_ids tensor([310, 297], device='cuda:0') ['of', 'in']
# next_token_ids tensor([278, 278], device='cuda:0') ['the', 'the']

View File

@@ -0,0 +1,209 @@
import multiprocessing
import time
import numpy as np
import torch
import torch.distributed as dist
from sglang.srt.managers.router.model_runner import ModelRunner
from sglang.srt.model_config import ModelConfig
def test_generate_worker(
    model_path, tp_rank, tp_size, batch_size, input_len, output_len
):
    """Run prefill plus ``output_len`` decode steps twice: warm-up, then timed.

    The warm-up pass prints logits on rank 0 and frees its cache entries. The
    benchmark pass prints per-phase wall-clock timings on rank 0.

    The nested ``prefill`` and ``decode`` helpers read ``req_pool_indices``,
    ``seq_lens``, ``prefix_lens``, ``position_ids_offsets`` and
    ``out_cache_loc`` from this function's scope, so rebinding those names
    between the two passes changes what the helpers operate on.
    """
    model_config = ModelConfig(path=model_path)
    model = ModelRunner(model_config, 0.8, tp_rank, tp_size, 28888)

    # Prepare data
    # Every request gets the same ids 5 .. input_len + 4, flattened to 1-D.
    input_ids = np.vstack([np.arange(5, input_len + 5) for _ in range(batch_size)])
    input_ids = input_ids.reshape(-1)
    input_ids = torch.tensor(input_ids).cuda()

    def init_batch_data(model, batch_size, input_len):
        """Allocate request slots and cache locations for a fresh batch."""
        req_pool_indices = model.req_to_token_pool.alloc(batch_size)
        seq_lens = torch.full(
            (batch_size,), input_len, dtype=torch.int32, device="cuda"
        )
        prefix_lens = torch.zeros(batch_size, dtype=torch.int32, device="cuda")
        position_ids_offsets = torch.zeros(batch_size, dtype=torch.int32, device="cuda")

        out_cache_loc = model.token_to_kv_pool.alloc(batch_size * input_len)
        # Record each request's slice of cache locations in its allocated row.
        for i in range(batch_size):
            req_idx = req_pool_indices[i].item()
            model.req_to_token_pool.req_to_token[req_idx, :input_len] = out_cache_loc[
                i * input_len : (i + 1) * input_len
            ]

        return (
            req_pool_indices,
            seq_lens,
            prefix_lens,
            position_ids_offsets,
            out_cache_loc,
        )

    def prefill(print_logits):
        """Run the prefill forward pass and store the argmax ids in ``predict_ids``."""
        nonlocal predict_ids

        logits, _ = model.forward_prefill(
            input_ids,
            req_pool_indices,
            seq_lens,
            prefix_lens,
            position_ids_offsets,
            out_cache_loc,
            False,
        )
        prob_out = torch.softmax(logits, dim=-1)
        predict_ids = torch.argmax(prob_out, dim=1, keepdim=True)
        predict_ids = predict_ids.detach().cpu().numpy()

        if print_logits and tp_rank == 0:
            print("prefill logits", logits, logits.shape)

    def decode(print_logits):
        """Run one decode step from ``predict_ids`` and store the new argmax ids."""
        nonlocal predict_ids

        # out_cache_loc is local to this helper: the tuple assignment below
        # binds a new name and does not touch the enclosing out_cache_loc.
        (
            out_cache_loc,
            out_cache_cont_start,
            out_cache_cont_end,
        ) = model.token_to_kv_pool.alloc_contiguous(batch_size)
        # Write the new location at column seq_lens, then grow the lengths in place.
        model.req_to_token_pool.req_to_token[req_pool_indices, seq_lens] = out_cache_loc
        seq_lens.add_(1)
        logits = model.forward_decode(
            torch.from_numpy(predict_ids).cuda().reshape(-1),
            req_pool_indices,
            seq_lens,
            None,
            position_ids_offsets,
            None,
            out_cache_cont_start,
            out_cache_cont_end,
        )
        prob_out = torch.softmax(logits, dim=-1)
        predict_ids = torch.argmax(prob_out, dim=1, keepdim=True)
        predict_ids = predict_ids.detach().cpu().numpy()

        if print_logits and tp_rank == 0:
            # `i` is the enclosing function's loop variable at the time of the call.
            print("decode", i, logits)

    # Warm up
    (
        req_pool_indices,
        seq_lens,
        prefix_lens,
        position_ids_offsets,
        out_cache_loc,
    ) = init_batch_data(model, batch_size, input_len)
    predict_ids = None

    prefill(True)
    for i in range(output_len):
        decode(True)

    # Free the warm-up allocations: each request's used table entries first,
    # then the request slots themselves.
    for i in range(batch_size):
        req_idx = req_pool_indices[i].item()
        model.token_to_kv_pool.free(
            model.req_to_token_pool.req_to_token[req_idx, : seq_lens[i]]
        )
    model.req_to_token_pool.free(req_pool_indices)

    # Benchmark
    if tp_size > 1:
        dist.barrier()
    # The prefill timing below includes the allocation done by init_batch_data.
    start_time = prefill_start_time = time.time()

    (
        req_pool_indices,
        seq_lens,
        prefix_lens,
        position_ids_offsets,
        out_cache_loc,
    ) = init_batch_data(model, batch_size, input_len)

    prefill(False)

    if tp_rank == 0:
        print(f"prefill cost: {(time.time() - prefill_start_time) * 1000:.2f} ms")

    for i in range(output_len):
        step_start = time.time()

        decode(False)

        step_end = time.time()

        # Report every 100th step and the final step.
        if i % 100 == 0 or i == output_len - 1:
            if tp_rank == 0:
                print(f"step {i} cost: {(step_end - step_start) * 1000:.2f} ms")

    end_time = time.time()

    if tp_rank == 0:
        print(f"total cost: {(end_time - start_time) * 1000:.2f}")
def test_generate(model_path, tp_size, batch_size, input_len, output_len):
    """Run ``test_generate_worker`` in one process per tensor-parallel rank.

    Args:
        model_path: Model path forwarded to every worker.
        tp_size: Number of worker processes to start.
        batch_size: Number of requests per batch.
        input_len: Prompt length in tokens.
        output_len: Number of decode steps.

    Raises:
        RuntimeError: If any worker exits with a non-zero exit code. Without
            this check a crash inside a child process would only be printed by
            the child and the run would appear to succeed.
    """
    workers = []
    for tp_rank in range(tp_size):
        proc = multiprocessing.Process(
            target=test_generate_worker,
            args=(
                model_path,
                tp_rank,
                tp_size,
                batch_size,
                input_len,
                output_len,
            ),
        )
        proc.start()
        workers.append(proc)

    for proc in workers:
        proc.join()

    # After join() every process has an exit code; 0 means it finished cleanly.
    failed = [
        (tp_rank, proc.exitcode)
        for tp_rank, proc in enumerate(workers)
        if proc.exitcode != 0
    ]
    if failed:
        raise RuntimeError(
            f"worker(s) exited abnormally (tp_rank, exitcode): {failed}"
        )
if __name__ == "__main__":
    # Single rank, one request, 256 prompt tokens, 8 decode steps.
    test_generate(
        model_path="TinyLlama/TinyLlama-1.1B-Chat-v0.4",
        tp_size=1,
        batch_size=1,
        input_len=256,
        output_len=8,
    )
# test_generate("meta-llama/Llama-2-7b-chat-hf", 1, 16, 256, 8)
# Reference output for TinyLlama-1.1B-Chat-v0.4 (1, 32, 8)
# prefill logits tensor([[-1.3380e-03, 4.4702e-01, 2.9082e+00, ..., -1.8398e+00,
# 1.8281e+00, 2.1816e+00]], device='cuda:0')
# decode 0 tensor([[-0.3904, 0.8784, 3.6934, ..., -2.4473, 1.5811, 2.0098]],
# device='cuda:0')
# decode 1 tensor([[-0.3552, 0.0635, 2.5781, ..., -2.5820, 1.3047, 1.7607]],
# device='cuda:0')
# decode 2 tensor([[-1.5645, -1.1963, 3.8145, ..., -2.9766, 1.0244, 1.0645]],
# device='cuda:0')
# decode 3 tensor([[-1.3682, -0.6548, 4.2734, ..., -2.8711, 1.1172, 1.1494]],
# device='cuda:0')
# decode 4 tensor([[-1.0205, -0.0060, 4.4844, ..., -2.7090, 1.6143, 1.8135]],
# device='cuda:0')
# decode 5 tensor([[ 0.4260, 1.6006, 4.3633, ..., -2.2480, 2.5547, 2.8379]],
# device='cuda:0')
# decode 6 tensor([[ 0.7095, 2.1816, 5.0078, ..., -2.1309, 3.0293, 3.0840]],
# device='cuda:0')
# decode 7 tensor([[-0.2883, 1.1289, 4.7188, ..., -2.4023, 2.1055, 2.1836]],
# device='cuda:0')
# Reference output for TinyLlama-1.1B-Chat-v0.4 (1, 256, 8)
# prefill logits tensor([[-2.5840, -2.7227, 6.8047, ..., -2.3613, 0.1224, 0.5952]],
# device='cuda:0')
# decode 0 tensor([[-0.6235, -0.7690, 9.2891, ..., -1.4922, 2.8008, 2.9531]],
# device='cuda:0')
# decode 1 tensor([[-1.3662, -1.4648, 7.1250, ..., -1.7861, 1.7363, 1.8857]],
# device='cuda:0')
# decode 2 tensor([[-0.8540, -0.5947, 9.1328, ..., -2.1211, 2.9707, 2.8945]],
# device='cuda:0')
# decode 3 tensor([[ 0.0652, 1.0312, 8.1250, ..., -2.0586, 3.4727, 3.6172]],
# device='cuda:0')
# decode 4 tensor([[-0.0459, 1.0098, 9.1406, ..., -2.1797, 3.8320, 3.9355]],
# device='cuda:0')
# decode 5 tensor([[ 0.2964, 1.3564, 9.8828, ..., -2.1602, 4.1836, 4.2422]],
# device='cuda:0')
# decode 6 tensor([[ 0.6475, 1.8105, 10.1250, ..., -2.0098, 4.2578, 4.4062]],
# device='cuda:0')
# decode 7 tensor([[ 0.4985, 1.4746, 9.9062, ..., -1.9141, 3.9863, 4.3047]],
# device='cuda:0')

View File

@@ -0,0 +1,161 @@
import multiprocessing
import time
import numpy as np
import torch
import torch.distributed as dist
from sglang.srt.hf_transformers_utils import get_processor
from sglang.srt.managers.router.infer_batch import ForwardMode
from sglang.srt.managers.router.model_runner import InputMetadata, ModelRunner
from sglang.srt.model_config import ModelConfig
from sglang.srt.utils import load_image
def init_batch_data(model, batch_size, input_len):
    """Allocate request slots and cache locations for a fresh batch.

    Args:
        model: Object exposing ``req_to_token_pool`` and ``token_to_kv_pool``.
        batch_size: Number of requests in the batch.
        input_len: Number of tokens per request.

    Returns:
        Tuple ``(req_pool_indices, seq_lens, prefix_lens, position_ids_offsets,
        out_cache_loc)``. ``seq_lens`` is filled with ``input_len``; the prefix
        lengths and position offsets are zeros.
    """
    req_pool_indices = model.req_to_token_pool.alloc(batch_size)
    seq_lens = torch.full((batch_size,), input_len, dtype=torch.int32, device="cuda")
    prefix_lens = torch.zeros(batch_size, dtype=torch.int32, device="cuda")
    position_ids_offsets = torch.zeros(batch_size, dtype=torch.int32, device="cuda")

    out_cache_loc = model.token_to_kv_pool.alloc(batch_size * input_len)
    req_to_token = model.req_to_token_pool.req_to_token
    for i in range(batch_size):
        # Write into the row the pool actually allocated for this request.
        # Using the loop counter as the row is only correct when the pool
        # happens to hand out slots 0..batch_size-1 in order, and later steps
        # index the table by req_pool_indices.
        req_idx = req_pool_indices[i].item()
        req_to_token[req_idx, :input_len] = out_cache_loc[
            i * input_len : (i + 1) * input_len
        ]

    return (
        req_pool_indices,
        seq_lens,
        prefix_lens,
        position_ids_offsets,
        out_cache_loc,
    )
def prefill(model, tp_rank, params, print_logits):
    """Run the multi-modal extend pass and return greedy next-token ids.

    Args:
        model: Object exposing ``forward_extend_multi_modal``.
        tp_rank: Rank of the caller; logits are printed only on rank 0.
        params: Positional arguments for the forward call; ``False`` is
            appended as the final positional argument.
        print_logits: Whether to print the logits and their shape.

    Returns:
        NumPy array of argmax ids taken over dim 1 of the softmaxed logits,
        with that dimension kept.
    """
    logits, _ = model.forward_extend_multi_modal(*params, False)
    probs = torch.softmax(logits, dim=-1)
    next_ids = torch.argmax(probs, dim=1, keepdim=True).detach().cpu().numpy()

    should_print = print_logits and tp_rank == 0
    if should_print:
        print("prefill logits", logits, logits.shape)

    return next_ids
def decode(step, model, tp_rank, batch_size, predict_ids, params, print_logits):
    """Run one decode step and return the greedy next-token ids.

    Args:
        step: Step number, used only in the printed message.
        model: Object exposing ``token_to_kv_pool``, ``req_to_token_pool`` and
            ``forward_decode``.
        tp_rank: Rank of the caller; logits are printed only on rank 0.
        batch_size: Number of cache locations to allocate.
        predict_ids: NumPy array of the previous step's token ids.
        params: The 5-tuple returned by ``init_batch_data``; its ``seq_lens``
            tensor is incremented in place.
        print_logits: Whether to print the logits.

    Returns:
        NumPy array of argmax ids taken over dim 1 of the softmaxed logits,
        with that dimension kept.
    """
    # The prefix lengths and the prefill-time cache locations are not needed here.
    req_pool_indices, seq_lens, _, position_ids_offsets, _ = params

    new_loc, cont_start, cont_end = model.token_to_kv_pool.alloc_contiguous(
        batch_size
    )
    # Write the new location at column seq_lens, then grow the lengths in place.
    model.req_to_token_pool.req_to_token[req_pool_indices, seq_lens] = new_loc
    seq_lens.add_(1)

    flat_inputs = torch.from_numpy(predict_ids).cuda().reshape(-1)
    logits = model.forward_decode(
        flat_inputs,
        req_pool_indices,
        seq_lens,
        None,
        position_ids_offsets,
        None,
        cont_start,
        cont_end,
    )
    probs = torch.softmax(logits, dim=-1)
    next_ids = torch.argmax(probs, dim=1, keepdim=True).detach().cpu().numpy()

    if print_logits and tp_rank == 0:
        print("decode", step, logits)

    return next_ids
def test_generate_worker(
    model_path,
    tp_rank,
    tp_size,
):
    """Generate 17 tokens for a fixed image prompt and check the decoded text.

    Runs one multi-modal prefill and 16 decode steps with batch size 1, then
    asserts that the detokenized output equals the expected caption prefix.
    """
    model = ModelRunner(ModelConfig(path=model_path), 0.8, tp_rank, tp_size, 28888)
    # print(model.model)

    # Prepare data
    prompt = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. USER: <image>\nDescribe this picture ASSISTANT:"
    image = load_image("/home/ubuntu/sglang/test/lang/image.png")
    processor = get_processor("llava-hf/llava-1.5-7b-hf")
    text_ids = processor.tokenizer.encode(prompt)
    pixel_values = processor.image_processor(image)["pixel_values"]
    input_ids, offset = model.model.pad_input_ids(text_ids, [0])

    params = init_batch_data(model, 1, len(input_ids))

    # inference
    prefill_params = (
        torch.tensor(np.array(input_ids)).cuda(),
        np.array(pixel_values),
        [offset],
        *params,
    )
    predict_ids = prefill(model, tp_rank=0, params=prefill_params, print_logits=False)
    output_ids = [predict_ids[0][0]]

    for step in range(16):
        predict_ids = decode(
            step,
            model,
            tp_rank=0,
            batch_size=1,
            predict_ids=predict_ids,
            params=params,
            print_logits=False,
        )
        output_ids.append(predict_ids[0][0])

    # detokenization
    decoded = processor.tokenizer.batch_decode(
        [output_ids], skip_special_tokens=True, clean_up_tokenization_spaces=False
    )
    output = decoded[0]
    assert (
        output
        == "The image features a man standing on the back of a yellow taxi cab, holding"
    )
def test_generate(model_path, tp_size):
    """Run ``test_generate_worker`` in one process per tensor-parallel rank.

    Args:
        model_path: Model path forwarded to every worker.
        tp_size: Number of worker processes to start.

    Raises:
        RuntimeError: If any worker exits with a non-zero exit code. The
            worker's output check is an ``assert`` that runs in the child
            process, so without this check a failed assertion would never
            reach the caller.
    """
    workers = []
    for tp_rank in range(tp_size):
        proc = multiprocessing.Process(
            target=test_generate_worker,
            args=(
                model_path,
                tp_rank,
                tp_size,
            ),
        )
        proc.start()
        workers.append(proc)

    for proc in workers:
        proc.join()

    # After join() every process has an exit code; 0 means it finished cleanly.
    failed = [
        (tp_rank, proc.exitcode)
        for tp_rank, proc in enumerate(workers)
        if proc.exitcode != 0
    ]
    if failed:
        raise RuntimeError(
            f"worker(s) exited abnormally (tp_rank, exitcode): {failed}"
        )
if __name__ == "__main__":
    # Single-rank run against the LLaVA 1.5 7B checkpoint.
    test_generate(model_path="liuhaotian/llava-v1.5-7b", tp_size=1)

163
test/srt/test_flashinfer.py Normal file
View File

@@ -0,0 +1,163 @@
import flashinfer
import pytest
import torch
from sglang.srt.layers.extend_attention import extend_attention_fwd
from sglang.srt.layers.token_attention import token_attention_fwd
@pytest.mark.parametrize("batch_size", [12, 37, 67])
@pytest.mark.parametrize("kv_len", [54, 97])
@pytest.mark.parametrize("qo_len", [37, 17])
@pytest.mark.parametrize("num_kv_heads", [4])
@pytest.mark.parametrize("num_qo_heads", [4, 32])
@pytest.mark.parametrize("head_dim", [128])
@pytest.mark.parametrize("use_wrapper", [True, False])
def test_batch_prefill_with_paged_kv_cache(
    batch_size,
    kv_len,
    qo_len,
    num_kv_heads,
    num_qo_heads,
    head_dim,
    use_wrapper,
):
    """Compare the Triton extend-attention kernel with flashinfer batch prefill.

    Both paths receive the same random fp16 query and KV data, laid out with
    one token per page. The outputs must agree within ``rtol=1e-2`` and
    ``atol=1e-3``.
    """

    def int_range(stop):
        # int32 sequence 0..stop-1 on device 0.
        return torch.arange(0, stop).to(0).int()

    def per_request(value):
        # int32 tensor of length batch_size filled with `value`, on device 0.
        return torch.full((batch_size,), value, dtype=torch.int32).to(0)

    total_tokens = kv_len * batch_size

    # Random inputs: the query is drawn first, then the KV data.
    q = torch.randn(batch_size * qo_len, num_qo_heads, head_dim).to(0).half()
    kv_data = torch.randn(total_tokens, 2, num_kv_heads, 1, head_dim).to(0).half()

    # flashinfer page-table arguments.
    q_indptr = int_range(batch_size + 1) * qo_len
    kv_indptr = int_range(batch_size + 1) * kv_len
    kv_indices = int_range(total_tokens)
    kv_last_page_len = per_request(1)

    # init args for triton kernel
    kv_by_request = kv_data.view(batch_size, kv_len, 2, -1)
    # The extend part is the last qo_len tokens of each request.
    k_extend, v_extend = (
        kv_by_request[:, -qo_len:, kv_slot]
        .contiguous()
        .view(-1, num_kv_heads, head_dim)
        for kv_slot in (0, 1)
    )
    k_buffer, v_buffer = (
        kv_data[:, kv_slot].view(-1, num_kv_heads, head_dim).contiguous()
        for kv_slot in (0, 1)
    )
    o_triton = torch.empty_like(q)

    req_to_token = int_range(total_tokens).view(batch_size, kv_len)
    b_req_idx = int_range(batch_size)
    b_seq_len = per_request(kv_len)
    b_start_loc_extend = int_range(batch_size) * qo_len
    b_seq_len_extend = per_request(qo_len)
    max_len_in_batch = kv_len
    max_len_extend = qo_len

    extend_attention_fwd(
        q,
        k_extend,
        v_extend,
        o_triton,
        k_buffer,
        v_buffer,
        req_to_token,
        b_req_idx,
        None,  # b_start_loc = None
        b_seq_len,
        None,  # b_seq_len_prefix = None
        b_start_loc_extend,
        b_seq_len_extend,
        max_len_in_batch,
        max_len_extend,
    )

    if use_wrapper:
        wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper()
        wrapper.begin_forward(q_indptr, batch_size, num_qo_heads, num_kv_heads)
        o = wrapper.forward(
            q, q_indptr, kv_data, kv_indptr, kv_indices, kv_last_page_len
        )
    else:
        o = flashinfer.batch_prefill_with_paged_kv_cache(
            q,
            q_indptr,
            kv_data,
            kv_indptr,
            kv_indices,
            kv_last_page_len,
        )

    abs_diff = torch.abs(o - o_triton)
    print("Mean: ", torch.mean(abs_diff))
    print("Max: ", torch.max(abs_diff))
    assert torch.allclose(o, o_triton, rtol=1e-2, atol=1e-3)
@pytest.mark.parametrize("batch_size", [12, 17, 37])
@pytest.mark.parametrize("kv_len", [54, 127, 537])
@pytest.mark.parametrize("num_kv_heads", [32])
@pytest.mark.parametrize("num_qo_heads", [32])
@pytest.mark.parametrize("head_dim", [128])
def test_batch_decode_with_paged_kv_cache(
    batch_size,
    kv_len,
    num_kv_heads,
    num_qo_heads,
    head_dim,
):
    """Compare the Triton token-attention kernel with flashinfer batch decode.

    Both paths receive the same random fp16 query and KV data, laid out with
    one token per page. The outputs must agree within ``rtol=1e-2`` and
    ``atol=2e-3``.
    """
    # note(lsyin): when pytest, the number of heads cannot change, because triton kernel has a cache
    # to test different shape of decode, change the parameters in the __main__, and run decode only once

    def int_range(stop):
        # int32 sequence 0..stop-1 on device 0.
        return torch.arange(0, stop).to(0).int()

    def per_request(value):
        # int32 tensor of length batch_size filled with `value`, on device 0.
        return torch.full((batch_size,), value, dtype=torch.int32).to(0)

    total_tokens = kv_len * batch_size

    # Random inputs: the query is drawn first, then the KV data.
    q = torch.randn(batch_size, num_qo_heads, head_dim).to(0).half()
    kv_data = torch.randn(total_tokens, 2, num_kv_heads, 1, head_dim).to(0).half()

    # flashinfer page-table arguments.
    kv_indptr = int_range(batch_size + 1) * kv_len
    kv_indices = int_range(total_tokens)
    kv_last_page_len = per_request(1)

    # init args for triton kernel
    k_buffer, v_buffer = (
        kv_data[:, kv_slot].view(-1, num_kv_heads, head_dim).contiguous()
        for kv_slot in (0, 1)
    )
    o_triton = torch.empty_like(q)
    req_to_token = int_range(kv_len * batch_size).view(batch_size, kv_len)
    b_req_idx = int_range(batch_size)
    b_start_loc = int_range(batch_size) * kv_len
    b_seq_len = per_request(kv_len)
    max_len_in_batch = kv_len
    other_kv_index = 0

    token_attention_fwd(
        q,
        k_buffer,
        v_buffer,
        o_triton,
        req_to_token,
        b_req_idx,
        b_start_loc,
        b_seq_len,
        max_len_in_batch,
        other_kv_index,
        total_tokens,
    )

    wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper()
    wrapper.begin_forward(
        kv_indptr,
        kv_last_page_len,
        batch_size,
        num_qo_heads,
        num_kv_heads,
        head_dim,
        1,
        "NONE",
        "float16",
    )
    o = wrapper.forward(q, kv_data, kv_indptr, kv_indices, kv_last_page_len)

    abs_diff = torch.abs(o - o_triton)
    print("Mean: ", torch.mean(abs_diff))
    print("Max: ", torch.max(abs_diff))
    assert torch.allclose(o, o_triton, rtol=1e-2, atol=2e-3)
if __name__ == "__main__":
    # Direct (non-pytest) runs. Prefill case arguments are, in order:
    # batch_size, kv_len, qo_len, num_kv_heads, num_qo_heads, head_dim, use_wrapper.
    prefill_cases = (
        (12, 54, 37, 8, 8, 128, False),
        (37, 1111, 456, 32, 32, 128, True),
    )
    for prefill_case in prefill_cases:
        test_batch_prefill_with_paged_kv_cache(*prefill_case)

    # batch_size, kv_len, num_kv_heads, num_qo_heads, head_dim
    test_batch_decode_with_paged_kv_cache(12, 54, 4, 32, 128)

View File

@@ -0,0 +1,56 @@
"""
python3 -m sglang.launch_server --model-path TinyLlama/TinyLlama-1.1B-Chat-v0.4 --port 30000
Output:
The capital of France is Paris.\nThe capital of the United States is Washington, D.C.
The capital of the United Kindom is London.\nThe capital of the United Kingdom is London.\nThe capital of
"""
import argparse
import asyncio
import json
import time
import aiohttp
import requests
async def send_request(url, data, delay=0):
    """POST ``data`` as JSON to ``url`` after ``delay`` seconds.

    Returns:
        The response body parsed as JSON.
    """
    await asyncio.sleep(delay)
    async with aiohttp.ClientSession() as session, session.post(
        url, json=data
    ) as resp:
        return await resp.json()
async def main(args):
    """Send two ``/generate`` requests concurrently and print both results.

    The first request is delayed by one second, so the second one reaches the
    server first.
    """
    endpoint = f"{args.host}:{args.port}" + "/generate"

    def greedy_request(text):
        # Greedy sampling with up to 128 new tokens.
        return {
            "text": text,
            "sampling_params": {"temperature": 0, "max_new_tokens": 128},
        }

    task1 = send_request(
        endpoint, greedy_request("The capital of France is"), delay=1
    )
    task2 = send_request(
        endpoint, greedy_request("The capital of the United Kindom is")
    )

    rets = await asyncio.gather(task1, task2)
    print(rets)
if __name__ == "__main__":
    # Address of the running sglang server.
    cli = argparse.ArgumentParser()
    cli.add_argument("--host", type=str, default="http://127.0.0.1")
    cli.add_argument("--port", type=int, default=30000)

    asyncio.run(main(cli.parse_args()))

View File

@@ -0,0 +1,31 @@
"""
python3 -m sglang.launch_server --model-path TinyLlama/TinyLlama-1.1B-Chat-v0.4 --port 30000
Output:
The capital of France is Paris.\nThe capital of the United States is Washington, D.C.\nThe capital of Canada is Ottawa.\nThe capital of Japan is Tokyo
"""
import argparse
import time
import requests
if __name__ == "__main__":
    # Address of the running sglang server.
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="http://127.0.0.1")
    parser.add_argument("--port", type=int, default=30000)
    args = parser.parse_args()

    url = f"{args.host}:{args.port}"

    # One greedy completion of up to 32 new tokens.
    payload = {
        "text": "The capital of France is",
        "sampling_params": {
            "temperature": 0,
            "max_new_tokens": 32,
        },
    }
    response = requests.post(url + "/generate", json=payload)
    print(response.json())

View File

@@ -0,0 +1,42 @@
"""
python3 -m sglang.launch_server --model-path TinyLlama/TinyLlama-1.1B-Chat-v0.4 --port 30000
Output:
The capital of France is Paris.\nThe capital of the United States is Washington, D.C.\nThe capital of Canada is Ottawa.\nThe capital of Japan is Tokyo
"""
import argparse
import json
import time
import requests
if __name__ == "__main__":
    # Address of the running sglang server.
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="http://127.0.0.1")
    parser.add_argument("--port", type=int, default=30000)
    args = parser.parse_args()

    url = f"{args.host}:{args.port}"

    # Streaming greedy completion of up to 1024 new tokens.
    payload = {
        "text": "The capital of France is",
        "sampling_params": {
            "temperature": 0,
            "max_new_tokens": 1024,
        },
        "stream": True,
    }
    response = requests.post(url + "/generate", json=payload, stream=True)

    # Chunks are split on NUL bytes. Each one is parsed as JSON, and only the
    # part of its stripped "text" beyond what was already printed is shown.
    prev = 0
    for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
        if not chunk:
            continue
        output = json.loads(chunk.decode())["text"].strip()
        print(output[prev:], end="", flush=True)
        prev = len(output)
    print("")

View File

@@ -0,0 +1,84 @@
"""
python3 -m sglang.launch_server --model-path liuhaotian/llava-v1.5-7b --tokenizer-path llava-hf/llava-1.5-7b-hf --port 30000
Output:
The image features a man standing on the back of a yellow taxi cab, holding
"""
import argparse
import asyncio
import json
import time
import aiohttp
import requests
async def send_request(url, data, delay=0):
    """POST ``data`` as JSON to ``url`` after ``delay`` seconds.

    Returns:
        The response body parsed as JSON.
    """
    await asyncio.sleep(delay)
    async with aiohttp.ClientSession() as session, session.post(
        url, json=data
    ) as resp:
        return await resp.json()
async def test_concurrent(args):
    """Send eight identical image requests concurrently and print each text."""
    endpoint = f"{args.host}:{args.port}" + "/generate"

    def build_payload():
        # A fresh dict for every request, as in the original loop.
        return {
            "text": "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. USER: <image>\nDescribe this picture ASSISTANT:",
            "image_data": "/home/ubuntu/sglang/test/lang/image.png",
            "sampling_params": {
                "temperature": 0,
                "max_new_tokens": 16,
            },
        }

    pending = [send_request(endpoint, build_payload()) for _ in range(8)]

    rets = await asyncio.gather(*pending)
    for ret in rets:
        print(ret["text"])
def test_streaming(args):
    """Stream one image-description completion and print it as it arrives."""
    url = f"{args.host}:{args.port}"

    # Streaming greedy completion of up to 128 new tokens.
    payload = {
        "text": "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. USER: <image>\nDescribe this picture ASSISTANT:",
        "image_data": "/home/ubuntu/sglang/test/lang/image.png",
        "sampling_params": {
            "temperature": 0,
            "max_new_tokens": 128,
        },
        "stream": True,
    }
    response = requests.post(url + "/generate", json=payload, stream=True)

    # Chunks are split on NUL bytes. Each one is parsed as JSON, and only the
    # part of its stripped "text" beyond what was already printed is shown.
    prev = 0
    for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
        if not chunk:
            continue
        output = json.loads(chunk.decode())["text"].strip()
        print(output[prev:], end="", flush=True)
        prev = len(output)
    print("")
if __name__ == "__main__":
    # Address of the running sglang server.
    cli = argparse.ArgumentParser()
    cli.add_argument("--host", type=str, default="http://127.0.0.1")
    cli.add_argument("--port", type=int, default=30000)
    args = cli.parse_args()

    # Concurrent requests first, then the streaming request.
    asyncio.run(test_concurrent(args))
    test_streaming(args)

View File

@@ -0,0 +1,43 @@
"""
python3 -m sglang.launch_server --model-path TinyLlama/TinyLlama-1.1B-Chat-v0.4 --port 30000
Output:
The capital of France is Paris.\nThe capital of the United States is Washington, D.C.\nThe capital of Canada is Ottawa.\nThe capital of Japan is Tokyo
"""
import argparse
import time
import requests
if __name__ == "__main__":
    # Address of the running sglang server.
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="http://127.0.0.1")
    parser.add_argument("--port", type=int, default=30000)
    args = parser.parse_args()

    url = f"{args.host}:{args.port}"

    # Two sequential greedy requests; the second prompt starts with the first.
    prompts = (
        "The capital of France is",
        "The capital of France is Paris.\nThe capital of the United States is",
    )
    for text in prompts:
        response = requests.post(
            url + "/generate",
            json={
                "text": text,
                "sampling_params": {
                    "temperature": 0,
                    "max_new_tokens": 32,
                },
            },
        )
        print(response.json())