# File: sglang/test/srt/model/test_llama_extend.py
import multiprocessing
import os
import time
import numpy as np
import torch
import torch.distributed as dist
import transformers
from sglang.srt.managers.router.infer_batch import Batch, ForwardMode, Req
from sglang.srt.managers.router.model_runner import ModelRunner
from sglang.srt.model_config import ModelConfig
from sglang.srt.sampling_params import SamplingParams
def test_generate_worker(model_path, tp_rank, tp_size):
    """Exercise the extend (chunked prefill) path of ModelRunner on one TP rank.

    Prefills only the first ``cut_num`` tokens of each prompt, then extends
    with the remaining prompt tokens while reusing the already-cached KV
    prefix, and finally decodes a few steps. Prints logits and next tokens
    for manual comparison against the reference output at the bottom of
    this file.

    Args:
        model_path: HuggingFace model path/name used for both model and tokenizer.
        tp_rank: this worker's tensor-parallel rank.
        tp_size: total number of tensor-parallel workers.
    """
    model_config = ModelConfig(path=model_path)
    model = ModelRunner(model_config, 0.8, tp_rank, tp_size, 28888)
    tokenizer = transformers.AutoTokenizer.from_pretrained(model_path)

    # Input
    prompts = [
        "The capital of France is",
        "Today is a sunny day and I like",
    ]
    # Greedy sampling (temperature=0) keeps the run deterministic so it can
    # be compared against the reference output below.
    sampling_params = SamplingParams(temperature=0)
    cut_num = 4

    reqs = []
    for i in range(len(prompts)):
        req = Req(i, None, None)
        # Only the first cut_num tokens go into the initial prefill.
        req.input_ids = tokenizer.encode(prompts[i])[:cut_num]
        req.sampling_params = sampling_params
        reqs.append(req)

    # Prefill: run the truncated prompts through the model.
    batch = Batch.init_new(reqs, model.req_to_token_pool, model.token_to_kv_pool, None)
    batch.prepare_for_extend(model.model_config.vocab_size, None)
    logits, _ = model.forward(batch, ForwardMode.EXTEND)
    next_token_ids, next_token_probs = batch.sample(logits)
    print("extend logits (first)", logits)

    # Extend: append the remaining prompt tokens, pointing prefix_indices at
    # the KV-cache slots already filled by the prefill above.
    for i in range(len(prompts)):
        req = reqs[i]
        req.input_ids += tokenizer.encode(prompts[i])[cut_num:]
        req.prefix_indices = model.req_to_token_pool.req_to_token[
            batch.req_pool_indices[i], :cut_num
        ]
    batch = Batch.init_new(reqs, model.req_to_token_pool, model.token_to_kv_pool, None)
    batch.prepare_for_extend(model.model_config.vocab_size, None)
    logits, _ = model.forward(batch, ForwardMode.EXTEND)
    next_token_ids, next_token_probs = batch.sample(logits)
    print("extend logits", logits)
    print(
        "next_token_ids", next_token_ids, [tokenizer.decode(x) for x in next_token_ids]
    )

    # Decode: autoregressively generate a few tokens.
    for _ in range(6):
        batch.prepare_for_decode(next_token_ids.cpu().numpy())
        # Fix: unpack the forward() return like the EXTEND calls above do;
        # the original passed the raw return straight into batch.sample().
        logits, _ = model.forward(batch, ForwardMode.DECODE)
        next_token_ids, next_token_probs = batch.sample(logits)
        print(
            "next_token_ids",
            next_token_ids,
            [tokenizer.decode(x) for x in next_token_ids],
        )
def test_generate(model_path, tp_size):
    """Launch one worker process per tensor-parallel rank and wait for all of them."""
    procs = [
        multiprocessing.Process(
            target=test_generate_worker,
            args=(model_path, rank, tp_size),
        )
        for rank in range(tp_size)
    ]
    for p in procs:
        p.start()
    for p in procs:
        p.join()
if __name__ == "__main__":
    # Disable HF tokenizers' internal parallelism before spawning worker
    # processes (presumably to silence fork warnings — TODO confirm).
    os.environ["TOKENIZERS_PARALLELISM"] = "false"
    test_generate("TinyLlama/TinyLlama-1.1B-Chat-v0.4", 1)
# Reference output for TinyLlama-1.1B-Chat-v0.4
# extend logits (first) tensor([[-10.0312, -9.5000, 0.8896, ..., -4.9375, -3.2402, -3.3633],
# [ -9.1797, -10.2500, 2.7168, ..., -4.3359, -4.0664, -4.1289]],
# device='cuda:0', dtype=torch.float16)
# extend logits tensor([[-8.3125, -7.1172, 3.3359, ..., -4.9531, -4.1289, -3.4121],
# [-9.6406, -9.0547, 4.0195, ..., -5.3086, -4.7188, -4.4609]],
# device='cuda:0', dtype=torch.float16)
# next_token_ids tensor([3681, 304], device='cuda:0') ['Paris', 'to']
# next_token_ids tensor([29889, 748], device='cuda:0') ['.', 'go']
# next_token_ids tensor([ 13, 363], device='cuda:0') ['\n', 'for']
# next_token_ids tensor([1576, 263], device='cuda:0') ['The', 'a']
# next_token_ids tensor([7483, 6686], device='cuda:0') ['capital', 'walk']
# next_token_ids tensor([310, 297], device='cuda:0') ['of', 'in']
# next_token_ids tensor([278, 278], device='cuda:0') ['the', 'the']