import multiprocessing as mp
import time
from dataclasses import dataclass

import torch
import torch.distributed as dist

from sglang.srt.managers.router.model_runner import ModelRunner
from sglang.srt.model_config import ModelConfig

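
# BenchBatch mirrors the batch metadata the router builds for a forward pass:
# it borrows the two memory pools from the ModelRunner and tracks, per
# request, which KV cache slots its tokens occupy.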
@dataclass
class BenchBatch:
    req_to_token_pool: torch.Tensor
    token_to_kv_pool: torch.Tensor

    input_ids: torch.Tensor = None
    position_ids_offsets: torch.Tensor = None
    seq_lens: torch.Tensor = None
    prefix_lens: torch.Tensor = None
    req_pool_indices: torch.Tensor = None
    out_cache_loc: torch.Tensor = None
    out_cache_cont_start: torch.Tensor = None
    out_cache_cont_end: torch.Tensor = None

    def __init__(self, model_runner: ModelRunner):
        self.req_to_token_pool = model_runner.req_to_token_pool
        self.token_to_kv_pool = model_runner.token_to_kv_pool

    def init_prefill_batch(self, input_ids, batch_size, seq_len):
        self.input_ids = input_ids
        self.position_ids_offsets = torch.zeros(
            batch_size, dtype=torch.int32, device="cuda"
        )
        self.seq_lens = torch.full(
            (batch_size,), seq_len, dtype=torch.int32, device="cuda"
        )
        self.prefix_lens = torch.zeros(batch_size, dtype=torch.int32, device="cuda")
        self.req_pool_indices = self.req_to_token_pool.alloc(batch_size)
        self.out_cache_loc = self.token_to_kv_pool.alloc(batch_size * seq_len)

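        # Record each request's freshly allocated KV slots in the
        # request-to-token map.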
        for i in range(batch_size):
            n_idx = self.req_pool_indices[i].item()
            self.req_to_token_pool.req_to_token[n_idx, :seq_len] = self.out_cache_loc[
                i * seq_len : (i + 1) * seq_len
            ]

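    # Fork new requests off already-prefilled prefix requests: the shared
    # prefix KV cache is reused, and only the extended tokens get new slots.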
    def update_extend(
        self, input_ids, batch_size, prefix_len, extend_len, prefix_req_idx
    ):
        self.input_ids = input_ids
        self.position_ids_offsets = torch.zeros(
            batch_size, dtype=torch.int32, device="cuda"
        )
        self.seq_lens = torch.full(
            (batch_size,), prefix_len + extend_len, dtype=torch.int32, device="cuda"
        )
        self.prefix_lens = torch.full(
            (batch_size,), prefix_len, dtype=torch.int32, device="cuda"
        )
        self.req_pool_indices = self.req_to_token_pool.alloc(batch_size)
        self.out_cache_loc = self.token_to_kv_pool.alloc(batch_size * extend_len)

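        # Each prefix request is forked `fork_num` times: copy its
        # token-to-KV mapping, then append the slots for the new tokens.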
        req_to_token = self.req_to_token_pool.req_to_token
        fork_num = batch_size // prefix_req_idx.shape[0]
        for i in range(batch_size):
            p_idx = prefix_req_idx[i // fork_num].item()
            n_idx = self.req_pool_indices[i].item()
            req_to_token[n_idx, :prefix_len] = req_to_token[p_idx, :prefix_len]
            req_to_token[
                n_idx, prefix_len : prefix_len + extend_len
            ] = self.out_cache_loc[i * extend_len : (i + 1) * extend_len]

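    # Advance every request by one token, allocating the new KV slots as one
    # contiguous block so a decode step can address them by [start, end).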
    def update_decode(self, predict_ids, batch_size):
        assert predict_ids.shape[0] == batch_size
        assert batch_size == self.req_pool_indices.shape[0]

        self.input_ids = predict_ids.reshape(-1)
        self.prefix_lens = None
        (
            self.out_cache_loc,
            self.out_cache_cont_start,
            self.out_cache_cont_end,
        ) = self.token_to_kv_pool.alloc_contiguous(batch_size)
        self.req_to_token_pool.req_to_token[
            self.req_pool_indices, self.seq_lens
        ] = self.out_cache_loc
        self.seq_lens.add_(1)

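
# Each helper below runs one forward pass and greedily picks (argmax) the
# next token for every request in the batch.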
def prefill(model_runner: ModelRunner, batch: BenchBatch):
    logits, _ = model_runner.forward_extend(
        batch.input_ids,
        batch.req_pool_indices,
        batch.seq_lens,
        batch.prefix_lens,
        batch.position_ids_offsets,
        batch.out_cache_loc,
        False,
    )

    prob_out = torch.softmax(logits, dim=-1)
    predict_ids = torch.argmax(prob_out, dim=1, keepdim=True)
    predict_ids = predict_ids.detach().cpu().numpy()

    return predict_ids


def extend(model_runner: ModelRunner, batch: BenchBatch):
    logits, _ = model_runner.forward_extend(
        batch.input_ids,
        batch.req_pool_indices,
        batch.seq_lens,
        batch.prefix_lens,
        batch.position_ids_offsets,
        batch.out_cache_loc,
        True,
    )

    prob_out = torch.softmax(logits, dim=-1)
    predict_ids = torch.argmax(prob_out, dim=1, keepdim=True)
    predict_ids = predict_ids.detach().cpu().numpy()

    return predict_ids


def decode(model_runner: ModelRunner, batch: BenchBatch):
    logits = model_runner.forward_decode(
        batch.input_ids,
        batch.req_pool_indices,
        batch.seq_lens,
        None,
        batch.position_ids_offsets,
        None,
        batch.out_cache_cont_start,
        batch.out_cache_cont_end,
    )

    prob_out = torch.softmax(logits, dim=-1)
    predict_ids = torch.argmax(prob_out, dim=1, keepdim=True)
    predict_ids = predict_ids.detach().cpu().numpy()

    return predict_ids

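
# One worker per tensor-parallel rank. Workload: prefill `shared_num` shared
# prefixes, fork them into `unique_num` requests that each extend with unique
# tokens, then decode for `decode_len` steps.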
def bench_generate_worker(
    model_path,
    tp_rank,
    tp_size,
    shared_num,
    unique_num,
    shared_len,
    unique_len,
    decode_len,
    model_mode,
):
    assert unique_num % shared_num == 0

    model_config = ModelConfig(path=model_path)
    model_runner = ModelRunner(
        model_config=model_config,
        mem_fraction_static=0.8,
        tp_rank=tp_rank,
        tp_size=tp_size,
        nccl_port=28888,
        model_mode=model_mode,
    )

    batch = BenchBatch(model_runner)

    # warm up
    for _ in range(1):
        input_ids = torch.randint(
            low=5, high=100, size=(shared_num * shared_len,)
        ).cuda()
        batch.init_prefill_batch(input_ids, shared_num, shared_len)
        _ = prefill(model_runner, batch)

        input_ids = torch.randint(
            low=5, high=100, size=(unique_num * unique_len,)
        ).cuda()
        batch.update_extend(
            input_ids, unique_num, shared_len, unique_len, batch.req_pool_indices
        )
        predict_ids = extend(model_runner, batch)

        for i in range(decode_len):
            predict_ids = torch.from_numpy(predict_ids).cuda()
            batch.update_decode(predict_ids, unique_num)
            predict_ids = decode(model_runner, batch)

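    # Reset the memory pools so the timed run starts from an empty cache.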
    model_runner.req_to_token_pool.clear()
    model_runner.token_to_kv_pool.clear()

    if tp_size > 1:
        dist.barrier()

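    # Timed run. The .cpu() copy inside each helper synchronizes the GPU, so
    # the wall-clock deltas below cover the full forward pass.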
    prefill_start = time.time()
    input_ids = torch.randint(low=5, high=100, size=(shared_num * shared_len,)).cuda()
    batch.init_prefill_batch(input_ids, shared_num, shared_len)
    _ = prefill(model_runner, batch)
    if tp_rank == 0:
        print(f"prefill: {(time.time() - prefill_start) * 1000:.2f} ms")

    extend_start = time.time()
    input_ids = torch.randint(low=5, high=100, size=(unique_num * unique_len,)).cuda()
    batch.update_extend(
        input_ids, unique_num, shared_len, unique_len, batch.req_pool_indices
    )
    predict_ids = extend(model_runner, batch)
    if tp_rank == 0:
        print(f"extend: {(time.time() - extend_start) * 1000:.2f} ms")

    for i in range(decode_len):
        decode_start = time.time()
        predict_ids = torch.from_numpy(predict_ids).cuda()
        batch.update_decode(predict_ids, unique_num)
        predict_ids = decode(model_runner, batch)
        if tp_rank == 0:
            print(f"decode {i}: {(time.time() - decode_start) * 1000:.2f} ms")

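
# Spawn one process per tensor-parallel rank and wait for all of them.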
def bench_generate(
    model_path,
    tp_size,
    shared_num,
    unique_num,
    shared_len,
    unique_len,
    decode_len,
    model_mode,
):
    print(
        f"tp_size: {tp_size}, "
        f"shared_num: {shared_num}, "
        f"unique_num: {unique_num}, "
        f"shared_len: {shared_len}, "
        f"unique_len: {unique_len}, "
        f"decode_len: {decode_len}, "
        f"model_mode: {model_mode}"
    )
    workers = []
    for tp_rank in range(tp_size):
        proc = mp.Process(
            target=bench_generate_worker,
            args=(
                model_path,
                tp_rank,
                tp_size,
                shared_num,
                unique_num,
                shared_len,
                unique_len,
                decode_len,
                model_mode,
            ),
        )
        proc.start()
        workers.append(proc)

    for proc in workers:
        proc.join()

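
# Example: one shared prefix of 256 tokens forked into 32 requests, each
# extended by 256 unique tokens and decoded for 8 steps on a single GPU.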
if __name__ == "__main__":
    bench_generate(
        model_path="meta-llama/Llama-2-7b-chat-hf",
        tp_size=1,
        shared_num=1,
        unique_num=32,
        shared_len=256,
        unique_len=256,
        decode_len=8,
        model_mode=[],
    )