release initial code
Co-authored-by: Ying Sheng <sqy1415@gmail.com> Co-authored-by: Liangsheng Yin <hnyls2002@gmail.com> Co-authored-by: Zhiqiang Xie <xiezhq@stanford.edu> Co-authored-by: parasol-aser <3848358+parasol-aser@users.noreply.github.com> Co-authored-by: LiviaSun <33578456+ChuyueSun@users.noreply.github.com> Co-authored-by: Cody Yu <hao.yu.cody@gmail.com>
This commit is contained in:
274
test/srt/model/bench_llama_low_api.py
Normal file
274
test/srt/model/bench_llama_low_api.py
Normal file
@@ -0,0 +1,274 @@
|
||||
import multiprocessing as mp
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from sglang.srt.managers.router.model_runner import ModelRunner
|
||||
from sglang.srt.model_config import ModelConfig
|
||||
|
||||
|
||||
@dataclass
|
||||
class BenchBatch:
|
||||
req_to_token_pool: torch.Tensor
|
||||
token_to_kv_pool: torch.Tensor
|
||||
|
||||
input_ids: torch.Tensor = None
|
||||
position_ids_offsets: torch.Tensor = None
|
||||
seq_lens: torch.Tensor = None
|
||||
prefix_lens: torch.Tensor = None
|
||||
req_pool_indices: torch.Tensor = None
|
||||
out_cache_loc: torch.Tensor = None
|
||||
out_cache_cont_start: torch.Tensor = None
|
||||
out_cache_cont_end: torch.Tensor = None
|
||||
|
||||
def __init__(self, model_runner: ModelRunner):
|
||||
self.req_to_token_pool = model_runner.req_to_token_pool
|
||||
self.token_to_kv_pool = model_runner.token_to_kv_pool
|
||||
|
||||
def init_prefill_batch(self, input_ids, batch_size, seq_len):
|
||||
self.input_ids = input_ids
|
||||
self.position_ids_offsets = torch.zeros(
|
||||
batch_size, dtype=torch.int32, device="cuda"
|
||||
)
|
||||
self.seq_lens = torch.full(
|
||||
(batch_size,), seq_len, dtype=torch.int32, device="cuda"
|
||||
)
|
||||
self.prefix_lens = torch.zeros(batch_size, dtype=torch.int32, device="cuda")
|
||||
self.req_pool_indices = self.req_to_token_pool.alloc(batch_size)
|
||||
self.out_cache_loc = self.token_to_kv_pool.alloc(batch_size * seq_len)
|
||||
|
||||
for i in range(batch_size):
|
||||
n_idx = self.req_pool_indices[i].item()
|
||||
self.req_to_token_pool.req_to_token[n_idx, :seq_len] = self.out_cache_loc[
|
||||
i * seq_len : (i + 1) * seq_len
|
||||
]
|
||||
|
||||
def update_extend(
|
||||
self, input_ids, batch_size, prefix_len, extend_len, prefix_req_idx
|
||||
):
|
||||
self.input_ids = input_ids
|
||||
self.position_ids_offsets = torch.zeros(
|
||||
batch_size, dtype=torch.int32, device="cuda"
|
||||
)
|
||||
self.seq_lens = torch.full(
|
||||
(batch_size,), prefix_len + extend_len, dtype=torch.int32, device="cuda"
|
||||
)
|
||||
self.prefix_lens = torch.full(
|
||||
(batch_size,), prefix_len, dtype=torch.int32, device="cuda"
|
||||
)
|
||||
self.req_pool_indices = self.req_to_token_pool.alloc(batch_size)
|
||||
self.out_cache_loc = self.token_to_kv_pool.alloc(batch_size * extend_len)
|
||||
|
||||
req_to_token = self.req_to_token_pool.req_to_token
|
||||
fork_num = batch_size // prefix_req_idx.shape[0]
|
||||
for i in range(batch_size):
|
||||
p_idx = prefix_req_idx[i // fork_num].item()
|
||||
n_idx = self.req_pool_indices[i].item()
|
||||
req_to_token[n_idx, :prefix_len] = req_to_token[p_idx, :prefix_len]
|
||||
req_to_token[
|
||||
n_idx, prefix_len : prefix_len + extend_len
|
||||
] = self.out_cache_loc[i * extend_len : (i + 1) * extend_len]
|
||||
|
||||
def update_decode(self, predict_ids, batch_size):
|
||||
assert predict_ids.shape[0] == batch_size
|
||||
assert batch_size == self.req_pool_indices.shape[0]
|
||||
|
||||
self.input_ids = predict_ids.reshape(-1)
|
||||
self.prefix_lens = None
|
||||
(
|
||||
self.out_cache_loc,
|
||||
self.out_cache_cont_start,
|
||||
self.out_cache_cont_end,
|
||||
) = self.token_to_kv_pool.alloc_contiguous(batch_size)
|
||||
self.req_to_token_pool.req_to_token[
|
||||
self.req_pool_indices, self.seq_lens
|
||||
] = self.out_cache_loc
|
||||
self.seq_lens.add_(1)
|
||||
|
||||
|
||||
def prefill(model_runner: ModelRunner, batch: BenchBatch):
|
||||
logits, _ = model_runner.forward_extend(
|
||||
batch.input_ids,
|
||||
batch.req_pool_indices,
|
||||
batch.seq_lens,
|
||||
batch.prefix_lens,
|
||||
batch.position_ids_offsets,
|
||||
batch.out_cache_loc,
|
||||
False,
|
||||
)
|
||||
|
||||
prob_out = torch.softmax(logits, dim=-1)
|
||||
predict_ids = torch.argmax(prob_out, dim=1, keepdim=True)
|
||||
predict_ids = predict_ids.detach().cpu().numpy()
|
||||
|
||||
return predict_ids
|
||||
|
||||
|
||||
def extend(model_runner: ModelRunner, batch: BenchBatch):
|
||||
logits, _ = model_runner.forward_extend(
|
||||
batch.input_ids,
|
||||
batch.req_pool_indices,
|
||||
batch.seq_lens,
|
||||
batch.prefix_lens,
|
||||
batch.position_ids_offsets,
|
||||
batch.out_cache_loc,
|
||||
True,
|
||||
)
|
||||
|
||||
prob_out = torch.softmax(logits, dim=-1)
|
||||
predict_ids = torch.argmax(prob_out, dim=1, keepdim=True)
|
||||
predict_ids = predict_ids.detach().cpu().numpy()
|
||||
|
||||
return predict_ids
|
||||
|
||||
|
||||
def decode(model_runner: ModelRunner, batch: BenchBatch):
|
||||
logits = model_runner.forward_decode(
|
||||
batch.input_ids,
|
||||
batch.req_pool_indices,
|
||||
batch.seq_lens,
|
||||
None,
|
||||
batch.position_ids_offsets,
|
||||
None,
|
||||
batch.out_cache_cont_start,
|
||||
batch.out_cache_cont_end,
|
||||
)
|
||||
|
||||
prob_out = torch.softmax(logits, dim=-1)
|
||||
predict_ids = torch.argmax(prob_out, dim=1, keepdim=True)
|
||||
predict_ids = predict_ids.detach().cpu().numpy()
|
||||
|
||||
return predict_ids
|
||||
|
||||
|
||||
def bench_generate_worker(
|
||||
model_path,
|
||||
tp_rank,
|
||||
tp_size,
|
||||
shared_num,
|
||||
unique_num,
|
||||
shared_len,
|
||||
unique_len,
|
||||
decode_len,
|
||||
model_mode,
|
||||
):
|
||||
assert unique_num % shared_num == 0
|
||||
|
||||
model_config = ModelConfig(path=model_path)
|
||||
model_runner = ModelRunner(
|
||||
model_config=model_config,
|
||||
mem_fraction_static=0.8,
|
||||
tp_rank=tp_rank,
|
||||
tp_size=tp_size,
|
||||
nccl_port=28888,
|
||||
model_mode=model_mode,
|
||||
)
|
||||
|
||||
batch = BenchBatch(model_runner)
|
||||
|
||||
# warm up
|
||||
for _ in range(1):
|
||||
input_ids = torch.randint(
|
||||
low=5, high=100, size=(shared_num * shared_len,)
|
||||
).cuda()
|
||||
batch.init_prefill_batch(input_ids, shared_num, shared_len)
|
||||
_ = prefill(model_runner, batch)
|
||||
|
||||
input_ids = torch.randint(
|
||||
low=5, high=100, size=(unique_num * unique_len,)
|
||||
).cuda()
|
||||
batch.update_extend(
|
||||
input_ids, unique_num, shared_len, unique_len, batch.req_pool_indices
|
||||
)
|
||||
predict_ids = extend(model_runner, batch)
|
||||
|
||||
for i in range(decode_len):
|
||||
predict_ids = torch.from_numpy(predict_ids).cuda()
|
||||
batch.update_decode(predict_ids, unique_num)
|
||||
predict_ids = decode(model_runner, batch)
|
||||
|
||||
model_runner.req_to_token_pool.clear()
|
||||
model_runner.token_to_kv_pool.clear()
|
||||
|
||||
if tp_size > 1:
|
||||
dist.barrier()
|
||||
|
||||
prefill_start = time.time()
|
||||
input_ids = torch.randint(low=5, high=100, size=(shared_num * shared_len,)).cuda()
|
||||
batch.init_prefill_batch(input_ids, shared_num, shared_len)
|
||||
_ = prefill(model_runner, batch)
|
||||
if tp_rank == 0:
|
||||
print(f"prefill: {(time.time() - prefill_start) * 1000:.2f} ms")
|
||||
|
||||
extend_start = time.time()
|
||||
input_ids = torch.randint(low=5, high=100, size=(unique_num * unique_len,)).cuda()
|
||||
batch.update_extend(
|
||||
input_ids, unique_num, shared_len, unique_len, batch.req_pool_indices
|
||||
)
|
||||
predict_ids = extend(model_runner, batch)
|
||||
if tp_rank == 0:
|
||||
print(f"extend: {(time.time() - extend_start) * 1000:.2f} ms")
|
||||
|
||||
for i in range(decode_len):
|
||||
decode_start = time.time()
|
||||
predict_ids = torch.from_numpy(predict_ids).cuda()
|
||||
batch.update_decode(predict_ids, unique_num)
|
||||
predict_ids = decode(model_runner, batch)
|
||||
if tp_rank == 0:
|
||||
print(f"decode {i}: {(time.time() - decode_start) * 1000:.2f} ms")
|
||||
|
||||
|
||||
def bench_generate(
|
||||
model_path,
|
||||
tp_size,
|
||||
shared_num,
|
||||
unique_num,
|
||||
shared_len,
|
||||
unique_len,
|
||||
decode_len,
|
||||
model_mode,
|
||||
):
|
||||
print(
|
||||
f"tp_size: {tp_size}, "
|
||||
f"shared_num: {shared_num}, "
|
||||
f"unique_num: {unique_num}, "
|
||||
f"shared_len: {shared_len}, "
|
||||
f"unique_len: {unique_len}, "
|
||||
f"decode_len: {decode_len}, "
|
||||
f"model_mode: {model_mode}"
|
||||
)
|
||||
workers = []
|
||||
for tp_rank in range(tp_size):
|
||||
proc = mp.Process(
|
||||
target=bench_generate_worker,
|
||||
args=(
|
||||
model_path,
|
||||
tp_rank,
|
||||
tp_size,
|
||||
shared_num,
|
||||
unique_num,
|
||||
shared_len,
|
||||
unique_len,
|
||||
decode_len,
|
||||
model_mode,
|
||||
),
|
||||
)
|
||||
proc.start()
|
||||
workers.append(proc)
|
||||
|
||||
for proc in workers:
|
||||
proc.join()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
bench_generate(
|
||||
model_path="meta-llama/Llama-2-7b-chat-hf",
|
||||
tp_size=1,
|
||||
shared_num=1,
|
||||
unique_num=32,
|
||||
shared_len=256,
|
||||
unique_len=256,
|
||||
decode_len=8,
|
||||
model_mode=[],
|
||||
)
|
||||
Reference in New Issue
Block a user