Files
sglang/test/srt/model/test_llama_low_api.py
Lianmin Zheng 22085081bb release initial code
Co-authored-by: Ying Sheng <sqy1415@gmail.com>
Co-authored-by: Liangsheng Yin <hnyls2002@gmail.com>
Co-authored-by: Zhiqiang Xie <xiezhq@stanford.edu>
Co-authored-by: parasol-aser <3848358+parasol-aser@users.noreply.github.com>
Co-authored-by: LiviaSun <33578456+ChuyueSun@users.noreply.github.com>
Co-authored-by: Cody Yu <hao.yu.cody@gmail.com>
2024-01-08 04:37:50 +00:00

210 lines
7.1 KiB
Python

import multiprocessing
import time
import numpy as np
import torch
import torch.distributed as dist
from sglang.srt.managers.router.model_runner import ModelRunner
from sglang.srt.model_config import ModelConfig
def test_generate_worker(
    model_path, tp_rank, tp_size, batch_size, input_len, output_len
):
    """Run a prefill + decode loop on one tensor-parallel rank.

    Loads the model via ModelRunner, builds a synthetic batch of identical
    prompts, does one warm-up pass (printing logits for eyeball comparison
    against the reference output at the bottom of this file), then times a
    second prefill and `output_len` decode steps.

    Args:
        model_path: HF model path/name passed to ModelConfig.
        tp_rank: This worker's tensor-parallel rank (rank 0 does the printing).
        tp_size: Total number of tensor-parallel workers.
        batch_size: Number of identical sequences in the batch.
        input_len: Prompt length per sequence.
        output_len: Number of decode steps to run.
    """
    model_config = ModelConfig(path=model_path)
    # 0.8 is presumably a GPU memory fraction and 28888 a distributed init
    # port — TODO confirm against ModelRunner's signature.
    model = ModelRunner(model_config, 0.8, tp_rank, tp_size, 28888)

    # Prepare data: every sequence is the same synthetic prompt
    # [5, 6, ..., input_len + 4], flattened to a 1-D tensor of
    # batch_size * input_len token ids.
    input_ids = np.vstack([np.arange(5, input_len + 5) for _ in range(batch_size)])
    input_ids = input_ids.reshape(-1)
    input_ids = torch.tensor(input_ids).cuda()

    def init_batch_data(model, batch_size, input_len):
        """Allocate request slots and KV-cache space for a fresh batch.

        Returns the 5-tuple of per-batch tensors that prefill() consumes:
        (req_pool_indices, seq_lens, prefix_lens, position_ids_offsets,
        out_cache_loc).
        """
        req_pool_indices = model.req_to_token_pool.alloc(batch_size)
        seq_lens = torch.full(
            (batch_size,), input_len, dtype=torch.int32, device="cuda"
        )
        # No shared prefix and no position offset for this synthetic batch.
        prefix_lens = torch.zeros(batch_size, dtype=torch.int32, device="cuda")
        position_ids_offsets = torch.zeros(batch_size, dtype=torch.int32, device="cuda")
        # One KV-cache slot per prompt token across the whole batch.
        out_cache_loc = model.token_to_kv_pool.alloc(batch_size * input_len)
        # Record each request's slice of the KV cache in the req->token map.
        for i in range(batch_size):
            req_idx = req_pool_indices[i].item()
            model.req_to_token_pool.req_to_token[req_idx, :input_len] = out_cache_loc[
                i * input_len : (i + 1) * input_len
            ]
        return (
            req_pool_indices,
            seq_lens,
            prefix_lens,
            position_ids_offsets,
            out_cache_loc,
        )

    def prefill(print_logits):
        """Run one prefill forward pass and greedily pick the next tokens.

        Writes the argmax token ids (as a numpy array) into the enclosing
        `predict_ids` for the subsequent decode steps.
        """
        nonlocal predict_ids
        logits, _ = model.forward_prefill(
            input_ids,
            req_pool_indices,
            seq_lens,
            prefix_lens,
            position_ids_offsets,
            out_cache_loc,
            False,
        )
        # Greedy sampling: argmax over the vocab dimension.
        prob_out = torch.softmax(logits, dim=-1)
        predict_ids = torch.argmax(prob_out, dim=1, keepdim=True)
        predict_ids = predict_ids.detach().cpu().numpy()
        if print_logits and tp_rank == 0:
            print("prefill logits", logits, logits.shape)

    def decode(print_logits):
        """Run one decode step: extend the KV cache by one token per request.

        Feeds back the previous step's `predict_ids` and updates it with the
        new greedy tokens.
        """
        nonlocal predict_ids
        # Allocate one contiguous KV slot per request; contiguity lets the
        # kernel take start/end bounds instead of a location tensor.
        (
            out_cache_loc,
            out_cache_cont_start,
            out_cache_cont_end,
        ) = model.token_to_kv_pool.alloc_contiguous(batch_size)
        # Append the new slot at position seq_lens for every request, then
        # advance the lengths — order matters here.
        model.req_to_token_pool.req_to_token[req_pool_indices, seq_lens] = out_cache_loc
        seq_lens.add_(1)
        logits = model.forward_decode(
            torch.from_numpy(predict_ids).cuda().reshape(-1),
            req_pool_indices,
            seq_lens,
            None,
            position_ids_offsets,
            None,
            out_cache_cont_start,
            out_cache_cont_end,
        )
        prob_out = torch.softmax(logits, dim=-1)
        predict_ids = torch.argmax(prob_out, dim=1, keepdim=True)
        predict_ids = predict_ids.detach().cpu().numpy()
        if print_logits and tp_rank == 0:
            # `i` is the loop index from the enclosing decode loop.
            print("decode", i, logits)

    # Warm up
    (
        req_pool_indices,
        seq_lens,
        prefix_lens,
        position_ids_offsets,
        out_cache_loc,
    ) = init_batch_data(model, batch_size, input_len)
    predict_ids = None
    prefill(True)
    for i in range(output_len):
        decode(True)

    # Release the warm-up batch: free each request's KV slots, then the
    # request slots themselves.
    for i in range(batch_size):
        req_idx = req_pool_indices[i].item()
        model.token_to_kv_pool.free(
            model.req_to_token_pool.req_to_token[req_idx, : seq_lens[i]]
        )
    model.req_to_token_pool.free(req_pool_indices)

    # Benchmark
    if tp_size > 1:
        # Synchronize ranks so all timers start together.
        dist.barrier()
    start_time = prefill_start_time = time.time()
    (
        req_pool_indices,
        seq_lens,
        prefix_lens,
        position_ids_offsets,
        out_cache_loc,
    ) = init_batch_data(model, batch_size, input_len)
    prefill(False)
    if tp_rank == 0:
        print(f"prefill cost: {(time.time() - prefill_start_time) * 1000:.2f} ms")
    for i in range(output_len):
        step_start = time.time()
        decode(False)
        step_end = time.time()
        # Only log every 100th step (and the last) to keep output short.
        if i % 100 == 0 or i == output_len - 1:
            if tp_rank == 0:
                print(f"step {i} cost: {(step_end - step_start) * 1000:.2f} ms")
    end_time = time.time()
    if tp_rank == 0:
        print(f"total cost: {(end_time - start_time) * 1000:.2f}")
def test_generate(model_path, tp_size, batch_size, input_len, output_len):
    """Spawn one worker process per tensor-parallel rank and wait for all.

    Each rank runs test_generate_worker with the same batch configuration;
    rank 0 prints logits and timing.
    """
    procs = []
    for rank in range(tp_size):
        worker = multiprocessing.Process(
            target=test_generate_worker,
            args=(model_path, rank, tp_size, batch_size, input_len, output_len),
        )
        worker.start()
        procs.append(worker)
    # Block until every rank has finished.
    for worker in procs:
        worker.join()
if __name__ == "__main__":
    # Smoke test: 1 TP rank, batch size 1, 256 prompt tokens, 8 decode steps.
    test_generate("TinyLlama/TinyLlama-1.1B-Chat-v0.4", 1, 1, 256, 8)
    # test_generate("meta-llama/Llama-2-7b-chat-hf", 1, 16, 256, 8)
# Reference output for TinyLlama-1.1B-Chat-v0.4 (1, 32, 8)
# prefill logits tensor([[-1.3380e-03, 4.4702e-01, 2.9082e+00, ..., -1.8398e+00,
# 1.8281e+00, 2.1816e+00]], device='cuda:0')
# decode 0 tensor([[-0.3904, 0.8784, 3.6934, ..., -2.4473, 1.5811, 2.0098]],
# device='cuda:0')
# decode 1 tensor([[-0.3552, 0.0635, 2.5781, ..., -2.5820, 1.3047, 1.7607]],
# device='cuda:0')
# decode 2 tensor([[-1.5645, -1.1963, 3.8145, ..., -2.9766, 1.0244, 1.0645]],
# device='cuda:0')
# decode 3 tensor([[-1.3682, -0.6548, 4.2734, ..., -2.8711, 1.1172, 1.1494]],
# device='cuda:0')
# decode 4 tensor([[-1.0205, -0.0060, 4.4844, ..., -2.7090, 1.6143, 1.8135]],
# device='cuda:0')
# decode 5 tensor([[ 0.4260, 1.6006, 4.3633, ..., -2.2480, 2.5547, 2.8379]],
# device='cuda:0')
# decode 6 tensor([[ 0.7095, 2.1816, 5.0078, ..., -2.1309, 3.0293, 3.0840]],
# device='cuda:0')
# decode 7 tensor([[-0.2883, 1.1289, 4.7188, ..., -2.4023, 2.1055, 2.1836]],
# device='cuda:0')
# Reference output for TinyLlama-1.1B-Chat-v0.4 (1, 256, 8)
# prefill logits tensor([[-2.5840, -2.7227, 6.8047, ..., -2.3613, 0.1224, 0.5952]],
# device='cuda:0')
# decode 0 tensor([[-0.6235, -0.7690, 9.2891, ..., -1.4922, 2.8008, 2.9531]],
# device='cuda:0')
# decode 1 tensor([[-1.3662, -1.4648, 7.1250, ..., -1.7861, 1.7363, 1.8857]],
# device='cuda:0')
# decode 2 tensor([[-0.8540, -0.5947, 9.1328, ..., -2.1211, 2.9707, 2.8945]],
# device='cuda:0')
# decode 3 tensor([[ 0.0652, 1.0312, 8.1250, ..., -2.0586, 3.4727, 3.6172]],
# device='cuda:0')
# decode 4 tensor([[-0.0459, 1.0098, 9.1406, ..., -2.1797, 3.8320, 3.9355]],
# device='cuda:0')
# decode 5 tensor([[ 0.2964, 1.3564, 9.8828, ..., -2.1602, 4.1836, 4.2422]],
# device='cuda:0')
# decode 6 tensor([[ 0.6475, 1.8105, 10.1250, ..., -2.0098, 4.2578, 4.4062]],
# device='cuda:0')
# decode 7 tensor([[ 0.4985, 1.4746, 9.9062, ..., -1.9141, 3.9863, 4.3047]],
# device='cuda:0')