Add sglang.bench_latency for offline benchmark (#564)
@@ -1,5 +1,28 @@
"""
Usage:
python3 reference_hf.py --model TinyLlama/TinyLlama-1.1B-Chat-v0.4

Reference output:
<s> The capital of France is Paris.
The capital of the United States is Washington, D.C.
The capital of Canada is Ottawa.
The capital of Japan is Tokyo
prefill logits tensor([-8.3125, -7.1172,  3.3398,  ..., -4.9570, -4.1328, -3.4141],
       device='cuda:0')
<s> The capital of the United Kindom is London.
The capital of the United Kingdom is London.
The capital of the United Kingdom is London.
The capital of the United Kingdom is London.
prefill logits tensor([-8.9062, -9.0156,  4.1406,  ..., -4.9922, -4.4961, -4.0742],
       device='cuda:0')
<s> Today is a sunny day and I like to go for a walk in the park.
I'm going to the park to play in the grass and water.
Today is a very
prefill logits tensor([-9.6328, -9.0547,  4.0195,  ..., -5.3047, -4.7148, -4.4609],
       device='cuda:0')
"""

import argparse
import os

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
@@ -40,7 +63,6 @@ def normal_text(args):

@torch.inference_mode()
def synthetic_tokens(args):
    t = AutoTokenizer.from_pretrained(args.model_path)
    m = AutoModelForCausalLM.from_pretrained(
        args.model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True
    )
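
# Note: the HuggingFace greedy-decoding outputs above serve as a reference
# that the output of the new sglang.bench_latency correctness test (added
# below) can be checked against.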

python/sglang/bench_latency.py (new file, 288 lines)
@@ -0,0 +1,288 @@
"""
Benchmark the latency of a given model. It accepts arguments similar to those of launch_server.py.

# Usage (latency test):
python -m sglang.bench_latency --model-path meta-llama/Meta-Llama-3-8B-Instruct --load-format dummy

# Usage (correctness test):
python -m sglang.bench_latency --model-path TinyLlama/TinyLlama-1.1B-Chat-v0.4 --correctness-test

### Reference output:
prefill logits (first half) tensor([[-10.0312,  -9.5000,   0.8936,  ...,  -4.9414,  -3.2402,  -3.3633],
        [-10.0312,  -9.5000,   0.8936,  ...,  -4.9414,  -3.2402,  -3.3633],
        [ -9.1875, -10.2500,   2.7109,  ...,  -4.3359,  -4.0664,  -4.1328]],
       device='cuda:0', dtype=torch.float16)
prefill logits (final) tensor([[-8.3203, -7.1211,  3.3379,  ..., -4.9570, -4.1328, -3.4141],
        [-8.9062, -9.0156,  4.1445,  ..., -4.9922, -4.4961, -4.0742],
        [-9.6328, -9.0547,  4.0117,  ..., -5.3047, -4.7148, -4.4609]],
       device='cuda:0', dtype=torch.float16)
<s> The capital of France is.
The capital of the United States is Washington, D.C.

<s> The capital of the United Kindom is.
The capital of the United Kingdom is London.
The capital of the
<s> Today is a sunny day and I like go for a walk in the park.
I'm going to the park
"""
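
# One hypothetical invocation combining the server flags above with the
# benchmark flags defined by BenchArgs below (--batch-size, --input-len, and
# --output-len are this script's own flags; the model path is only an example):
#
#   python -m sglang.bench_latency --model-path meta-llama/Meta-Llama-3-8B-Instruct \
#       --load-format dummy --batch-size 8 --input-len 512 --output-len 32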

import argparse
import dataclasses
import logging
import multiprocessing
import time

import numpy as np
import torch
import torch.distributed as dist

from sglang.srt.hf_transformers_utils import get_tokenizer
from sglang.srt.managers.controller.infer_batch import Batch, ForwardMode, Req
from sglang.srt.managers.controller.model_runner import ModelRunner
from sglang.srt.model_config import ModelConfig
from sglang.srt.sampling_params import SamplingParams
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import suppress_other_loggers


@dataclasses.dataclass
class BenchArgs:
    batch_size: int = 1
    input_len: int = 1024
    output_len: int = 4
    correctness_test: bool = False
    # This is only used for the correctness test
    cut_len: int = 4

    @staticmethod
    def add_cli_args(parser: argparse.ArgumentParser):
        parser.add_argument("--batch-size", type=int, default=BenchArgs.batch_size)
        parser.add_argument("--input-len", type=int, default=BenchArgs.input_len)
        parser.add_argument("--output-len", type=int, default=BenchArgs.output_len)
        parser.add_argument("--correctness-test", action="store_true")
        parser.add_argument("--cut-len", type=int, default=BenchArgs.cut_len)

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace):
        attrs = [attr.name for attr in dataclasses.fields(cls)]
        return cls(**{attr: getattr(args, attr) for attr in attrs})
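

# Because from_cli_args() introspects the dataclass fields, every field above
# maps one-to-one to a CLI flag (argparse stores "--cut-len" as args.cut_len),
# e.g., hypothetically:
#   parser.parse_args(["--cut-len", "8"])  ->  BenchArgs(cut_len=8, ...)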


def load_model(server_args, tp_rank):
    suppress_other_loggers()

    model_config = ModelConfig(path=server_args.model_path)
    model_runner = ModelRunner(
        model_config=model_config,
        mem_fraction_static=server_args.mem_fraction_static,
        gpu_id=tp_rank,
        tp_rank=tp_rank,
        tp_size=server_args.tp_size,
        nccl_port=28888,
        server_args=server_args,
    )
    tokenizer = get_tokenizer(
        server_args.tokenizer_path,
        tokenizer_mode=server_args.tokenizer_mode,
        trust_remote_code=server_args.trust_remote_code,
    )
    if server_args.tp_size > 1:
        dist.barrier()
    return model_runner, tokenizer


def prepare_inputs(bench_args, tokenizer):
    prompts = [
        "The capital of France is",
        "The capital of the United Kindom is",
        "Today is a sunny day and I like",
    ]
    input_ids = [tokenizer.encode(p) for p in prompts]
    sampling_params = SamplingParams(
        temperature=0,
        max_new_tokens=bench_args.output_len,
    )

    reqs = []
    for i in range(len(prompts)):
        assert len(input_ids[i]) > bench_args.cut_len

        # Prefill only the first cut_len tokens here; the remainder is fed
        # later through prepare_extend_inputs to exercise the extend path.
        tmp_input_ids = input_ids[i][: bench_args.cut_len]
        req = Req(rid=i, origin_input_text=prompts[i], origin_input_ids=tmp_input_ids)
        req.prefix_indices = []
        req.sampling_params = sampling_params
        req.input_ids = req.origin_input_ids
        reqs.append(req)

    return input_ids, reqs


def prepare_extend_inputs(bench_args, input_ids, reqs, model_runner):
    for i in range(len(reqs)):
        req = reqs[i]
        req.input_ids += input_ids[i][bench_args.cut_len :]
        # Reuse the KV-cache entries written during the first prefill as the
        # cached prefix of this request.
        req.prefix_indices = model_runner.req_to_token_pool.req_to_token[
            i, : bench_args.cut_len
        ]
    return reqs


def prepare_synthetic_inputs(bench_args, tokenizer):
    # Token id 1 at every position; only the shape matters for latency.
    input_ids = np.ones((bench_args.batch_size, bench_args.input_len), dtype=np.int32)
    sampling_params = SamplingParams(
        temperature=0,
        max_new_tokens=bench_args.output_len,
    )

    reqs = []
    for i in range(len(input_ids)):
        req = Req(rid=i, origin_input_text="", origin_input_ids=list(input_ids[i]))
        req.prefix_indices = []
        req.sampling_params = sampling_params
        req.input_ids = req.origin_input_ids
        reqs.append(req)

    return reqs


def extend(reqs, model_runner):
    batch = Batch.init_new(
        reqs=reqs,
        req_to_token_pool=model_runner.req_to_token_pool,
        token_to_kv_pool=model_runner.token_to_kv_pool,
        tree_cache=None,
    )
    batch.prepare_for_extend(model_runner.model_config.vocab_size, None)
    output = model_runner.forward(batch, ForwardMode.EXTEND)
    next_token_ids, _ = batch.sample(output.next_token_logits)
    return next_token_ids, output.next_token_logits, batch


def decode(input_token_ids, batch, model_runner):
    batch.prepare_for_decode(input_token_ids.cpu().numpy())
    output = model_runner.forward(batch, ForwardMode.DECODE)
    next_token_ids, _ = batch.sample(output.next_token_logits)
    return next_token_ids, output.next_token_logits
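

# ForwardMode.EXTEND runs one forward pass over all new input tokens of each
# request (prefill, or extension of a cached prefix), while ForwardMode.DECODE
# processes exactly one new token per request against the existing KV cache.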


def correctness_test(
    server_args,
    bench_args,
    tp_rank,
):
    # Only rank 0 prints, so tensor-parallel workers do not interleave output.
    rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None

    # Load the model
    model_runner, tokenizer = load_model(server_args, tp_rank)

    # Prepare inputs
    input_ids, reqs = prepare_inputs(bench_args, tokenizer)

    # Prefill
    next_token_ids, next_token_logits, batch = extend(reqs, model_runner)
    rank_print("prefill logits (first half)", next_token_logits)

    # Prepare extend inputs
    reqs = prepare_extend_inputs(bench_args, input_ids, reqs, model_runner)

    # Extend
    next_token_ids, next_token_logits, batch = extend(reqs, model_runner)
    rank_print("prefill logits (final)", next_token_logits)

    # Decode
    output_ids = [list(req.input_ids) for req in reqs]
    for _ in range(bench_args.output_len):
        next_token_ids, _ = decode(next_token_ids, batch, model_runner)
        for i in range(len(reqs)):
            output_ids[i].append(next_token_ids[i])

    # Print
    for i in range(len(reqs)):
        print(tokenizer.decode(output_ids[i]))
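

# How the correctness test works: each prompt is split at cut_len. The first
# extend() prefills only the first cut_len tokens; the second extend() feeds
# the remainder while reusing the cached prefix, so the "final" logits should
# match a single full prefill (compare against the docstring reference).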


def latency_test(
    server_args,
    bench_args,
    tp_rank,
):
    rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None

    # Load the model
    model_runner, tokenizer = load_model(server_args, tp_rank)

    # Prepare inputs
    reqs = prepare_synthetic_inputs(bench_args, tokenizer)

    def clear():
        model_runner.req_to_token_pool.clear()
        model_runner.token_to_kv_pool.clear()

    @torch.inference_mode()
    def run_once(output_len):
        # Prefill
        torch.cuda.synchronize()
        tic = time.time()
        next_token_ids, _, batch = extend(reqs, model_runner)
        torch.cuda.synchronize()
        latency = time.time() - tic
        throughput = bench_args.input_len * bench_args.batch_size / latency
        rank_print(f"Prefill. latency: {latency * 1000:6.3f} ms, throughput: {throughput:9.2f} token/s")

        # Decode
        for _ in range(output_len):
            torch.cuda.synchronize()
            tic = time.time()
            next_token_ids, _ = decode(next_token_ids, batch, model_runner)
            torch.cuda.synchronize()
            latency = time.time() - tic
            throughput = bench_args.batch_size / latency
            rank_print(f"Decode. latency: {latency * 1000:6.3f} ms, throughput: {throughput:9.2f} token/s")

    # Warm up: the first run pays CUDA kernel and allocator initialization
    # costs, so it is discarded before the measured run below.
    run_once(4)
    clear()

    # Run again
    run_once(bench_args.output_len)
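

# On the numbers printed by run_once(): prefill throughput counts all
# batch_size * input_len input tokens of the single extend() pass, while each
# decode step produces one token per request (batch_size tokens); both are
# divided by that step's synchronized wall-clock latency.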


def main(server_args, bench_args):
    print(bench_args)

    if bench_args.correctness_test:
        work_func = correctness_test
    else:
        work_func = latency_test

    # Launch one worker process per tensor-parallel rank.
    workers = []
    for tp_rank in range(server_args.tp_size):
        proc = multiprocessing.Process(
            target=work_func,
            args=(
                server_args,
                bench_args,
                tp_rank,
            ),
        )
        proc.start()
        workers.append(proc)

    for proc in workers:
        proc.join()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    ServerArgs.add_cli_args(parser)
    BenchArgs.add_cli_args(parser)
    args = parser.parse_args()

    server_args = ServerArgs.from_cli_args(args)
    bench_args = BenchArgs.from_cli_args(args)

    logging.basicConfig(
        level=getattr(logging, server_args.log_level.upper()),
        format="%(message)s",
    )

    main(server_args, bench_args)

python/sglang/srt/managers/controller/model_runner.py
@@ -23,6 +23,7 @@ from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import (
    get_available_gpu_memory,
    is_multimodal_model,
    monkey_patch_vllm_dummy_weight_loader,
    monkey_patch_vllm_p2p_access_check,
)

@@ -229,6 +230,7 @@ class ModelRunner:
        self.nccl_port = nccl_port
        self.server_args = server_args
        self.is_multimodal_model = is_multimodal_model(self.model_config)
        monkey_patch_vllm_dummy_weight_loader()

        # Init torch distributed
        logger.info(f"[gpu_id={self.gpu_id}] Set cuda device.")

python/sglang/srt/utils.py
@@ -466,6 +466,48 @@ def monkey_patch_vllm_p2p_access_check(gpu_id: int):
    setattr(tgt, "gpu_p2p_access_check", lambda *arg, **kwargs: True)


def monkey_patch_vllm_dummy_weight_loader():
    """
    Monkey patch the dummy weight loader in vllm to call process_weights_after_loading.
    """

    from vllm.model_executor.model_loader.loader import (
        ModelConfig, DeviceConfig, LoRAConfig, VisionLanguageConfig,
        ParallelConfig, SchedulerConfig, CacheConfig, nn,
        set_default_torch_dtype, _initialize_model, initialize_dummy_weights,
        DummyModelLoader,
    )

    def load_model(self, *, model_config: ModelConfig,
                   device_config: DeviceConfig,
                   lora_config: Optional[LoRAConfig],
                   vision_language_config: Optional[VisionLanguageConfig],
                   parallel_config: ParallelConfig,
                   scheduler_config: SchedulerConfig,
                   cache_config: CacheConfig) -> nn.Module:
        with set_default_torch_dtype(model_config.dtype):
            with torch.device(device_config.device):
                model = _initialize_model(model_config, self.load_config,
                                          lora_config, vision_language_config,
                                          cache_config)

                for _, module in model.named_modules():
                    quant_method = getattr(module, "quant_method", None)
                    if quant_method is not None:
                        # Build the packed/processed weight buffers even
                        # though the weights themselves are random.
                        quant_method.process_weights_after_loading(module)
                    # FIXME: Remove this after Mixtral is updated
                    # to use quant_method.
                    if hasattr(module, "process_weights_after_loading"):
                        module.process_weights_after_loading()

            # NOTE(woosuk): For accurate performance evaluation, we assign
            # random values to the weights.
            initialize_dummy_weights(model)
        return model.eval()

    setattr(DummyModelLoader, "load_model", load_model)
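
# Why this matters for bench_latency: --load-format dummy goes through vllm's
# DummyModelLoader, and skipping process_weights_after_loading would leave
# quantized modules without their packed weight buffers, so dummy-weight runs
# would not exercise the same kernels as runs with real checkpoints.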


API_KEY_HEADER_NAME = "X-API-Key"

@@ -1,275 +0,0 @@
import multiprocessing as mp
import time
from dataclasses import dataclass

import torch
import torch.distributed as dist

from sglang.srt.managers.controller.model_runner import ModelRunner
from sglang.srt.model_config import ModelConfig


@dataclass
class BenchBatch:
    req_to_token_pool: torch.Tensor
    token_to_kv_pool: torch.Tensor

    input_ids: torch.Tensor = None
    position_ids_offsets: torch.Tensor = None
    seq_lens: torch.Tensor = None
    prefix_lens: torch.Tensor = None
    req_pool_indices: torch.Tensor = None
    out_cache_loc: torch.Tensor = None
    out_cache_cont_start: torch.Tensor = None
    out_cache_cont_end: torch.Tensor = None

    def __init__(self, model_runner: ModelRunner):
        self.req_to_token_pool = model_runner.req_to_token_pool
        self.token_to_kv_pool = model_runner.token_to_kv_pool

    def init_prefill_batch(self, input_ids, batch_size, seq_len):
        self.input_ids = input_ids
        self.position_ids_offsets = torch.zeros(
            batch_size, dtype=torch.int32, device="cuda"
        )
        self.seq_lens = torch.full(
            (batch_size,), seq_len, dtype=torch.int32, device="cuda"
        )
        self.prefix_lens = torch.zeros(batch_size, dtype=torch.int32, device="cuda")
        self.req_pool_indices = self.req_to_token_pool.alloc(batch_size)
        self.out_cache_loc = self.token_to_kv_pool.alloc(batch_size * seq_len)

        for i in range(batch_size):
            n_idx = self.req_pool_indices[i].item()
            self.req_to_token_pool.req_to_token[n_idx, :seq_len] = self.out_cache_loc[
                i * seq_len : (i + 1) * seq_len
            ]

    def update_extend(
        self, input_ids, batch_size, prefix_len, extend_len, prefix_req_idx
    ):
        self.input_ids = input_ids
        self.position_ids_offsets = torch.zeros(
            batch_size, dtype=torch.int32, device="cuda"
        )
        self.seq_lens = torch.full(
            (batch_size,), prefix_len + extend_len, dtype=torch.int32, device="cuda"
        )
        self.prefix_lens = torch.full(
            (batch_size,), prefix_len, dtype=torch.int32, device="cuda"
        )
        self.req_pool_indices = self.req_to_token_pool.alloc(batch_size)
        self.out_cache_loc = self.token_to_kv_pool.alloc(batch_size * extend_len)

        req_to_token = self.req_to_token_pool.req_to_token
        fork_num = batch_size // prefix_req_idx.shape[0]
        for i in range(batch_size):
            p_idx = prefix_req_idx[i // fork_num].item()
            n_idx = self.req_pool_indices[i].item()
            req_to_token[n_idx, :prefix_len] = req_to_token[p_idx, :prefix_len]
            req_to_token[n_idx, prefix_len : prefix_len + extend_len] = (
                self.out_cache_loc[i * extend_len : (i + 1) * extend_len]
            )

    def update_decode(self, predict_ids, batch_size):
        assert predict_ids.shape[0] == batch_size
        assert batch_size == self.req_pool_indices.shape[0]

        self.input_ids = predict_ids.reshape(-1)
        self.prefix_lens = None
        (
            self.out_cache_loc,
            self.out_cache_cont_start,
            self.out_cache_cont_end,
        ) = self.token_to_kv_pool.alloc_contiguous(batch_size)
        self.req_to_token_pool.req_to_token[self.req_pool_indices, self.seq_lens] = (
            self.out_cache_loc
        )
        self.seq_lens.add_(1)


def prefill(model_runner: ModelRunner, batch: BenchBatch):
    logits, _ = model_runner.forward_extend(
        batch.input_ids,
        batch.req_pool_indices,
        batch.seq_lens,
        batch.prefix_lens,
        batch.position_ids_offsets,
        batch.out_cache_loc,
        False,
    )

    prob_out = torch.softmax(logits, dim=-1)
    predict_ids = torch.argmax(prob_out, dim=1, keepdim=True)
    predict_ids = predict_ids.detach().cpu().numpy()

    return predict_ids


def extend(model_runner: ModelRunner, batch: BenchBatch):
    logits, _ = model_runner.forward_extend(
        batch.input_ids,
        batch.req_pool_indices,
        batch.seq_lens,
        batch.prefix_lens,
        batch.position_ids_offsets,
        batch.out_cache_loc,
        True,
    )

    prob_out = torch.softmax(logits, dim=-1)
    predict_ids = torch.argmax(prob_out, dim=1, keepdim=True)
    predict_ids = predict_ids.detach().cpu().numpy()

    return predict_ids


def decode(model_runner: ModelRunner, batch: BenchBatch):
    logits = model_runner.forward_decode(
        batch.input_ids,
        batch.req_pool_indices,
        batch.seq_lens,
        None,
        batch.position_ids_offsets,
        None,
        batch.out_cache_cont_start,
        batch.out_cache_cont_end,
    )

    prob_out = torch.softmax(logits, dim=-1)
    predict_ids = torch.argmax(prob_out, dim=1, keepdim=True)
    predict_ids = predict_ids.detach().cpu().numpy()

    return predict_ids


def bench_generate_worker(
    model_path,
    tp_rank,
    tp_size,
    shared_num,
    unique_num,
    shared_len,
    unique_len,
    decode_len,
    server_args_dict,
):
    assert unique_num % shared_num == 0

    model_config = ModelConfig(path=model_path)
    model_runner = ModelRunner(
        model_config=model_config,
        mem_fraction_static=0.8,
        tp_rank=tp_rank,
        tp_size=tp_size,
        nccl_port=28888,
        server_args_dict=server_args_dict,
    )

    batch = BenchBatch(model_runner)

    # warm up
    for _ in range(1):
        input_ids = torch.randint(
            low=5, high=100, size=(shared_num * shared_len,)
        ).cuda()
        batch.init_prefill_batch(input_ids, shared_num, shared_len)
        _ = prefill(model_runner, batch)

        input_ids = torch.randint(
            low=5, high=100, size=(unique_num * unique_len,)
        ).cuda()
        batch.update_extend(
            input_ids, unique_num, shared_len, unique_len, batch.req_pool_indices
        )
        predict_ids = extend(model_runner, batch)

        for i in range(decode_len):
            predict_ids = torch.from_numpy(predict_ids).cuda()
            batch.update_decode(predict_ids, unique_num)
            predict_ids = decode(model_runner, batch)

        model_runner.req_to_token_pool.clear()
        model_runner.token_to_kv_pool.clear()

    if tp_size > 1:
        dist.barrier()

    prefill_start = time.time()
    input_ids = torch.randint(low=5, high=100, size=(shared_num * shared_len,)).cuda()
    batch.init_prefill_batch(input_ids, shared_num, shared_len)
    _ = prefill(model_runner, batch)
    if tp_rank == 0:
        print(f"prefill: {(time.time() - prefill_start) * 1000:.2f} ms")

    extend_start = time.time()
    input_ids = torch.randint(low=5, high=100, size=(unique_num * unique_len,)).cuda()
    batch.update_extend(
        input_ids, unique_num, shared_len, unique_len, batch.req_pool_indices
    )
    predict_ids = extend(model_runner, batch)
    if tp_rank == 0:
        print(f"extend: {(time.time() - extend_start) * 1000:.2f} ms")

    for i in range(decode_len):
        decode_start = time.time()
        predict_ids = torch.from_numpy(predict_ids).cuda()
        batch.update_decode(predict_ids, unique_num)
        predict_ids = decode(model_runner, batch)
        if tp_rank == 0:
            print(f"decode {i}: {(time.time() - decode_start) * 1000:.2f} ms")


def bench_generate(
    model_path,
    tp_size,
    shared_num,
    unique_num,
    shared_len,
    unique_len,
    decode_len,
    server_args_dict,
):
    print(
        f"tp_size: {tp_size}, "
        f"shared_num: {shared_num}, "
        f"unique_num: {unique_num}, "
        f"shared_len: {shared_len}, "
        f"unique_len: {unique_len}, "
        f"decode_len: {decode_len}, "
        f"server_args: {server_args_dict}"
    )
    workers = []
    for tp_rank in range(tp_size):
        proc = mp.Process(
            target=bench_generate_worker,
            args=(
                model_path,
                tp_rank,
                tp_size,
                shared_num,
                unique_num,
                shared_len,
                unique_len,
                decode_len,
                server_args_dict,
            ),
        )
        proc.start()
        workers.append(proc)

    for proc in workers:
        proc.join()


if __name__ == "__main__":
    bench_generate(
        model_path="meta-llama/Llama-2-7b-chat-hf",
        tp_size=1,
        shared_num=1,
        unique_num=32,
        shared_len=256,
        unique_len=256,
        decode_len=8,
        server_args_dict={},
    )

@@ -1,109 +0,0 @@
import multiprocessing
import os
import time

import numpy as np
import torch
import torch.distributed as dist
import transformers

from sglang.srt.managers.controller.infer_batch import Batch, ForwardMode, Req
from sglang.srt.managers.controller.model_runner import ModelRunner
from sglang.srt.model_config import ModelConfig
from sglang.srt.sampling_params import SamplingParams


def test_generate_worker(model_path, tp_rank, tp_size):
    model_config = ModelConfig(path=model_path)
    model = ModelRunner(model_config, 0.8, tp_rank, tp_size, 28888)
    tokenizer = transformers.AutoTokenizer.from_pretrained(model_path)

    # Input
    prompts = [
        "The capital of France is",
        "Today is a sunny day and I like",
    ]
    sampling_params = SamplingParams(temperature=0)

    cut_num = 4

    reqs = []
    for i in range(len(prompts)):
        input_ids = tokenizer.encode(prompts[i])[:cut_num]
        req = Req(i, prompts[i], input_ids)
        req.sampling_params = sampling_params
        reqs.append(req)

    # Prefill
    batch = Batch.init_new(reqs, model.req_to_token_pool, model.token_to_kv_pool, None)
    batch.prepare_for_extend(model.model_config.vocab_size, None)
    logits, _ = model.forward(batch, ForwardMode.EXTEND)
    next_token_ids, next_token_probs = batch.sample(logits)
    print("extend logits (first)", logits)

    # Extend
    for i in range(len(prompts)):
        req = reqs[i]
        req.input_ids += tokenizer.encode(prompts[i])[cut_num:]
        req.prefix_indices = model.req_to_token_pool.req_to_token[
            batch.req_pool_indices[i], :cut_num
        ]
    batch = Batch.init_new(reqs, model.req_to_token_pool, model.token_to_kv_pool, None)
    batch.prepare_for_extend(model.model_config.vocab_size, None)
    logits, _ = model.forward(batch, ForwardMode.EXTEND)
    next_token_ids, next_token_probs = batch.sample(logits)

    print("extend logits", logits)
    print(
        "next_token_ids", next_token_ids, [tokenizer.decode(x) for x in next_token_ids]
    )

    # Decode
    for i in range(6):
        batch.prepare_for_decode(next_token_ids.cpu().numpy())
        logits, _ = model.forward(batch, ForwardMode.DECODE)
        next_token_ids, next_token_probs = batch.sample(logits)

        print(
            "next_token_ids",
            next_token_ids,
            [tokenizer.decode(x) for x in next_token_ids],
        )


def test_generate(model_path, tp_size):
    workers = []
    for tp_rank in range(tp_size):
        proc = multiprocessing.Process(
            target=test_generate_worker,
            args=(
                model_path,
                tp_rank,
                tp_size,
            ),
        )
        proc.start()
        workers.append(proc)

    for proc in workers:
        proc.join()


if __name__ == "__main__":
    os.environ["TOKENIZERS_PARALLELISM"] = "false"
    test_generate("TinyLlama/TinyLlama-1.1B-Chat-v0.4", 1)

# Reference output for TinyLlama-1.1B-Chat-v0.4
# extend logits (first) tensor([[-10.0312, -9.5000, 0.8896, ..., -4.9375, -3.2402, -3.3633],
#         [ -9.1797, -10.2500, 2.7168, ..., -4.3359, -4.0664, -4.1289]],
#        device='cuda:0', dtype=torch.float16)
# extend logits tensor([[-8.3125, -7.1172, 3.3359, ..., -4.9531, -4.1289, -3.4121],
#         [-9.6406, -9.0547, 4.0195, ..., -5.3086, -4.7188, -4.4609]],
#        device='cuda:0', dtype=torch.float16)
# next_token_ids tensor([3681, 304], device='cuda:0') ['Paris', 'to']
# next_token_ids tensor([29889, 748], device='cuda:0') ['.', 'go']
# next_token_ids tensor([ 13, 363], device='cuda:0') ['\n', 'for']
# next_token_ids tensor([1576, 263], device='cuda:0') ['The', 'a']
# next_token_ids tensor([7483, 6686], device='cuda:0') ['capital', 'walk']
# next_token_ids tensor([310, 297], device='cuda:0') ['of', 'in']
# next_token_ids tensor([278, 278], device='cuda:0') ['the', 'the']

@@ -1,211 +0,0 @@
import multiprocessing
import time

import numpy as np
import torch
import torch.distributed as dist

from sglang.srt.managers.controller.model_runner import ModelRunner
from sglang.srt.model_config import ModelConfig


def test_generate_worker(
    model_path, tp_rank, tp_size, batch_size, input_len, output_len
):
    model_config = ModelConfig(path=model_path)
    model = ModelRunner(model_config, 0.8, tp_rank, tp_size, 28888)

    # Prepare data
    input_ids = np.vstack([np.arange(5, input_len + 5) for _ in range(batch_size)])
    input_ids = input_ids.reshape(-1)
    input_ids = torch.tensor(input_ids).cuda()

    def init_batch_data(model, batch_size, input_len):
        req_pool_indices = model.req_to_token_pool.alloc(batch_size)
        seq_lens = torch.full(
            (batch_size,), input_len, dtype=torch.int32, device="cuda"
        )
        prefix_lens = torch.zeros(batch_size, dtype=torch.int32, device="cuda")
        position_ids_offsets = torch.zeros(batch_size, dtype=torch.int32, device="cuda")

        out_cache_loc = model.token_to_kv_pool.alloc(batch_size * input_len)
        for i in range(batch_size):
            req_idx = req_pool_indices[i].item()
            model.req_to_token_pool.req_to_token[req_idx, :input_len] = out_cache_loc[
                i * input_len : (i + 1) * input_len
            ]

        return (
            req_pool_indices,
            seq_lens,
            prefix_lens,
            position_ids_offsets,
            out_cache_loc,
        )

    def prefill(print_logits):
        nonlocal predict_ids

        logits, _ = model.forward_prefill(
            input_ids,
            req_pool_indices,
            seq_lens,
            prefix_lens,
            position_ids_offsets,
            out_cache_loc,
            False,
        )
        prob_out = torch.softmax(logits, dim=-1)
        predict_ids = torch.argmax(prob_out, dim=1, keepdim=True)
        predict_ids = predict_ids.detach().cpu().numpy()

        if print_logits and tp_rank == 0:
            print("prefill logits", logits, logits.shape)

    def decode(print_logits):
        nonlocal predict_ids

        (
            out_cache_loc,
            out_cache_cont_start,
            out_cache_cont_end,
        ) = model.token_to_kv_pool.alloc_contiguous(batch_size)
        model.req_to_token_pool.req_to_token[req_pool_indices, seq_lens] = out_cache_loc
        seq_lens.add_(1)
        logits, _ = model.forward_decode(
            torch.from_numpy(predict_ids).cuda().reshape(-1),
            req_pool_indices,
            seq_lens,
            None,
            position_ids_offsets,
            None,
            out_cache_cont_start,
            out_cache_cont_end,
            False,
        )
        prob_out = torch.softmax(logits, dim=-1)
        predict_ids = torch.argmax(prob_out, dim=1, keepdim=True)
        predict_ids = predict_ids.detach().cpu().numpy()
        if print_logits and tp_rank == 0:
            print("decode", i, logits)

    # Warm up
    (
        req_pool_indices,
        seq_lens,
        prefix_lens,
        position_ids_offsets,
        out_cache_loc,
    ) = init_batch_data(model, batch_size, input_len)
    predict_ids = None

    prefill(True)
    for i in range(output_len):
        decode(True)

    for i in range(batch_size):
        req_idx = req_pool_indices[i].item()
        model.token_to_kv_pool.dec_refs(
            model.req_to_token_pool.req_to_token[req_idx, : seq_lens[i]]
        )
    model.req_to_token_pool.free(req_pool_indices)

    # Benchmark
    if tp_size > 1:
        dist.barrier()
    start_time = prefill_start_time = time.time()

    (
        req_pool_indices,
        seq_lens,
        prefix_lens,
        position_ids_offsets,
        out_cache_loc,
    ) = init_batch_data(model, batch_size, input_len)

    prefill(False)

    if tp_rank == 0:
        print(f"prefill cost: {(time.time() - prefill_start_time) * 1000:.2f} ms")

    for i in range(output_len):
        step_start = time.time()

        decode(False)

        step_end = time.time()

        if i % 100 == 0 or i == output_len - 1:
            if tp_rank == 0:
                print(f"step {i} cost: {(step_end - step_start) * 1000:.2f} ms")

    end_time = time.time()

    if tp_rank == 0:
        print(f"total cost: {(end_time - start_time) * 1000:.2f}")


def test_generate(model_path, tp_size, batch_size, input_len, output_len):
    workers = []
    for tp_rank in range(tp_size):
        proc = multiprocessing.Process(
            target=test_generate_worker,
            args=(
                model_path,
                tp_rank,
                tp_size,
                batch_size,
                input_len,
                output_len,
            ),
        )
        proc.start()
        workers.append(proc)

    for proc in workers:
        proc.join()


if __name__ == "__main__":
    test_generate("TinyLlama/TinyLlama-1.1B-Chat-v0.4", 1, 1, 256, 8)
    # test_generate("meta-llama/Llama-2-7b-chat-hf", 1, 16, 256, 8)

# Reference output for TinyLlama-1.1B-Chat-v0.4 (1, 32, 8)
# prefill logits tensor([[-1.3380e-03, 4.4702e-01, 2.9082e+00, ..., -1.8398e+00,
#          1.8281e+00, 2.1816e+00]], device='cuda:0')
# decode 0 tensor([[-0.3904, 0.8784, 3.6934, ..., -2.4473, 1.5811, 2.0098]],
#          device='cuda:0')
# decode 1 tensor([[-0.3552, 0.0635, 2.5781, ..., -2.5820, 1.3047, 1.7607]],
#          device='cuda:0')
# decode 2 tensor([[-1.5645, -1.1963, 3.8145, ..., -2.9766, 1.0244, 1.0645]],
#          device='cuda:0')
# decode 3 tensor([[-1.3682, -0.6548, 4.2734, ..., -2.8711, 1.1172, 1.1494]],
#          device='cuda:0')
# decode 4 tensor([[-1.0205, -0.0060, 4.4844, ..., -2.7090, 1.6143, 1.8135]],
#          device='cuda:0')
# decode 5 tensor([[ 0.4260, 1.6006, 4.3633, ..., -2.2480, 2.5547, 2.8379]],
#          device='cuda:0')
# decode 6 tensor([[ 0.7095, 2.1816, 5.0078, ..., -2.1309, 3.0293, 3.0840]],
#          device='cuda:0')
# decode 7 tensor([[-0.2883, 1.1289, 4.7188, ..., -2.4023, 2.1055, 2.1836]],
#          device='cuda:0')

# Reference output for TinyLlama-1.1B-Chat-v0.4 (1, 256, 8)
# prefill logits tensor([[-2.5840, -2.7227, 6.8047, ..., -2.3613, 0.1224, 0.5952]],
#          device='cuda:0')
# decode 0 tensor([[-0.6235, -0.7690, 9.2891, ..., -1.4922, 2.8008, 2.9531]],
#          device='cuda:0')
# decode 1 tensor([[-1.3662, -1.4648, 7.1250, ..., -1.7861, 1.7363, 1.8857]],
#          device='cuda:0')
# decode 2 tensor([[-0.8540, -0.5947, 9.1328, ..., -2.1211, 2.9707, 2.8945]],
#          device='cuda:0')
# decode 3 tensor([[ 0.0652, 1.0312, 8.1250, ..., -2.0586, 3.4727, 3.6172]],
#          device='cuda:0')
# decode 4 tensor([[-0.0459, 1.0098, 9.1406, ..., -2.1797, 3.8320, 3.9355]],
#          device='cuda:0')
# decode 5 tensor([[ 0.2964, 1.3564, 9.8828, ..., -2.1602, 4.1836, 4.2422]],
#          device='cuda:0')
# decode 6 tensor([[ 0.6475, 1.8105, 10.1250, ..., -2.0098, 4.2578, 4.4062]],
#          device='cuda:0')
# decode 7 tensor([[ 0.4985, 1.4746, 9.9062, ..., -1.9141, 3.9863, 4.3047]],
#          device='cuda:0')

@@ -1,164 +0,0 @@
import multiprocessing
import time

import numpy as np
import torch
import torch.distributed as dist

from sglang.srt.hf_transformers_utils import get_processor
from sglang.srt.managers.controller.infer_batch import ForwardMode
from sglang.srt.managers.controller.model_runner import InputMetadata, ModelRunner
from sglang.srt.model_config import ModelConfig
from sglang.srt.utils import load_image


def init_batch_data(model, batch_size, input_len):
    req_pool_indices = model.req_to_token_pool.alloc(batch_size)
    seq_lens = torch.full((batch_size,), input_len, dtype=torch.int32, device="cuda")
    prefix_lens = torch.zeros(batch_size, dtype=torch.int32, device="cuda")
    position_ids_offsets = torch.zeros(batch_size, dtype=torch.int32, device="cuda")

    out_cache_loc = model.token_to_kv_pool.alloc(batch_size * input_len)
    for i in range(batch_size):
        model.req_to_token_pool.req_to_token[i, :input_len] = out_cache_loc[
            i * input_len : (i + 1) * input_len
        ]

    return (
        req_pool_indices,
        seq_lens,
        prefix_lens,
        position_ids_offsets,
        out_cache_loc,
    )


def prefill(model, tp_rank, params, print_logits):
    logits, _ = model.forward_extend_multi_modal(
        *params,
        False,
    )
    prob_out = torch.softmax(logits, dim=-1)
    predict_ids = torch.argmax(prob_out, dim=1, keepdim=True)
    predict_ids = predict_ids.detach().cpu().numpy()

    if print_logits and tp_rank == 0:
        print("prefill logits", logits, logits.shape)

    return predict_ids


def decode(step, model, tp_rank, batch_size, predict_ids, params, print_logits):
    (
        req_pool_indices,
        seq_lens,
        prefix_lens,
        position_ids_offsets,
        out_cache_loc,
    ) = params

    (
        out_cache_loc,
        out_cache_cont_start,
        out_cache_cont_end,
    ) = model.token_to_kv_pool.alloc_contiguous(batch_size)
    model.req_to_token_pool.req_to_token[req_pool_indices, seq_lens] = out_cache_loc
    seq_lens.add_(1)
    logits, _ = model.forward_decode(
        torch.from_numpy(predict_ids).cuda().reshape(-1),
        req_pool_indices,
        seq_lens,
        None,
        position_ids_offsets,
        None,
        out_cache_cont_start,
        out_cache_cont_end,
        False,
    )
    prob_out = torch.softmax(logits, dim=-1)
    predict_ids = torch.argmax(prob_out, dim=1, keepdim=True)
    predict_ids = predict_ids.detach().cpu().numpy()
    if print_logits and tp_rank == 0:
        print("decode", step, logits)
    return predict_ids


def test_generate_worker(
    model_path,
    tp_rank,
    tp_size,
):
    model_config = ModelConfig(path=model_path)
    model = ModelRunner(model_config, 0.8, tp_rank, tp_size, 28888)
    # print(model.model)

    # Prepare data
    prompt = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. USER: <image>\nDescribe this picture ASSISTANT:"
    image_path = "/home/ubuntu/sglang/test/lang/test_image.png"
    image = load_image(image_path)

    processor = get_processor("llava-hf/llava-1.5-7b-hf")
    input_ids = processor.tokenizer.encode(prompt)
    pixel_values = processor.image_processor(image)["pixel_values"]
    input_ids, offset = model.model.pad_input_ids(
        input_ids,
        [
            0,
        ],
    )

    params = init_batch_data(model, 1, len(input_ids))

    # inference
    output_ids = []
    prefill_params = (
        torch.tensor(np.array(input_ids)).cuda(),
        np.array(pixel_values),
        [None],
        [offset],
        *params,
    )
    predict_ids = prefill(model, tp_rank=0, params=prefill_params, print_logits=False)
    output_ids.append(predict_ids[0][0])
    for i in range(16):
        predict_ids = decode(
            i,
            model,
            tp_rank=0,
            batch_size=1,
            predict_ids=predict_ids,
            params=params,
            print_logits=False,
        )
        output_ids.append(predict_ids[0][0])

    # detokenization
    output = processor.tokenizer.batch_decode(
        [output_ids], skip_special_tokens=True, clean_up_tokenization_spaces=False
    )[0]
    assert (
        output
        == "The image features a man standing on the back of a yellow taxi cab, holding"
    )


def test_generate(model_path, tp_size):
    workers = []
    for tp_rank in range(tp_size):
        proc = multiprocessing.Process(
            target=test_generate_worker,
            args=(
                model_path,
                tp_rank,
                tp_size,
            ),
        )
        proc.start()
        workers.append(proc)

    for proc in workers:
        proc.join()


if __name__ == "__main__":
    test_generate("liuhaotian/llava-v1.5-7b", 1)

@@ -1,5 +1,7 @@
"""
Usage:
python3 -m sglang.launch_server --model-path /model/llama-classification

python3 test_httpserver_classify.py
"""