diff --git a/test/srt/model/reference_hf.py b/playground/reference_hf.py
similarity index 69%
rename from test/srt/model/reference_hf.py
rename to playground/reference_hf.py
index e63866f02..ca82871c9 100644
--- a/test/srt/model/reference_hf.py
+++ b/playground/reference_hf.py
@@ -1,5 +1,28 @@
+"""
+Usage:
+python3 reference_hf.py --model TinyLlama/TinyLlama-1.1B-Chat-v0.4
+
+Reference output:
+ The capital of France is Paris.
+The capital of the United States is Washington, D.C.
+The capital of Canada is Ottawa.
+The capital of Japan is Tokyo
+prefill logits tensor([-8.3125, -7.1172, 3.3398, ..., -4.9570, -4.1328, -3.4141],
+       device='cuda:0')
+ The capital of the United Kindom is London.
+The capital of the United Kingdom is London.
+The capital of the United Kingdom is London.
+The capital of the United Kingdom is London.
+prefill logits tensor([-8.9062, -9.0156, 4.1406, ..., -4.9922, -4.4961, -4.0742],
+       device='cuda:0')
+ Today is a sunny day and I like to go for a walk in the park.
+I'm going to the park to play in the grass and water.
+Today is a very
+prefill logits tensor([-9.6328, -9.0547, 4.0195, ..., -5.3047, -4.7148, -4.4609],
+       device='cuda:0')
+"""
+
 import argparse
-import os
 
 import torch
 from transformers import AutoModelForCausalLM, AutoTokenizer
@@ -40,7 +63,6 @@ def normal_text(args):
 
 @torch.inference_mode()
 def synthetic_tokens(args):
-    t = AutoTokenizer.from_pretrained(args.model_path)
     m = AutoModelForCausalLM.from_pretrained(
         args.model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True
     )
diff --git a/python/sglang/bench_latency.py b/python/sglang/bench_latency.py
new file mode 100644
index 000000000..900272282
--- /dev/null
+++ b/python/sglang/bench_latency.py
@@ -0,0 +1,288 @@
+"""
+Benchmark the latency of a given model. It accepts arguments similar to those of launch_server.py.
+
+# Usage (latency test):
+python -m sglang.bench_latency --model-path meta-llama/Meta-Llama-3-8B-Instruct --load-format dummy
+
+# Usage (correctness test):
+python -m sglang.bench_latency --model-path TinyLlama/TinyLlama-1.1B-Chat-v0.4 --correctness-test
+
+### Reference output:
+prefill logits (first half) tensor([[-10.0312, -9.5000, 0.8936, ..., -4.9414, -3.2402, -3.3633],
+        [-10.0312, -9.5000, 0.8936, ..., -4.9414, -3.2402, -3.3633],
+        [ -9.1875, -10.2500, 2.7109, ..., -4.3359, -4.0664, -4.1328]],
+       device='cuda:0', dtype=torch.float16)
+prefill logits (final) tensor([[-8.3203, -7.1211, 3.3379, ..., -4.9570, -4.1328, -3.4141],
+        [-8.9062, -9.0156, 4.1445, ..., -4.9922, -4.4961, -4.0742],
+        [-9.6328, -9.0547, 4.0117, ..., -5.3047, -4.7148, -4.4609]],
+       device='cuda:0', dtype=torch.float16)
+ The capital of France is.
+The capital of the United States is Washington, D.C.
+
+ The capital of the United Kindom is.
+The capital of the United Kingdom is London.
+The capital of the
+ Today is a sunny day and I like go for a walk in the park.
+I'm going to the park
+"""
+
+import argparse
+import dataclasses
+import logging
+import multiprocessing
+import time
+
+import numpy as np
+import torch
+import torch.distributed as dist
+
+from sglang.srt.hf_transformers_utils import get_tokenizer
+from sglang.srt.managers.controller.infer_batch import Batch, ForwardMode, Req
+from sglang.srt.managers.controller.model_runner import ModelRunner
+from sglang.srt.model_config import ModelConfig
+from sglang.srt.sampling_params import SamplingParams
+from sglang.srt.server_args import ServerArgs
+from sglang.srt.utils import suppress_other_loggers
+
+
+@dataclasses.dataclass
+class BenchArgs:
+    batch_size: int = 1
+    input_len: int = 1024
+    output_len: int = 4
+    correctness_test: bool = False
+    # This is only used for correctness test
+    cut_len: int = 4
+
+    @staticmethod
+    def add_cli_args(parser: argparse.ArgumentParser):
+        parser.add_argument("--batch-size", type=int, default=BenchArgs.batch_size)
+        parser.add_argument("--input-len", type=int, default=BenchArgs.input_len)
+        parser.add_argument("--output-len", type=int, default=BenchArgs.output_len)
+        parser.add_argument("--correctness-test", action="store_true")
+        parser.add_argument("--cut-len", type=int, default=BenchArgs.cut_len)
+
+    @classmethod
+    def from_cli_args(cls, args: argparse.Namespace):
+        attrs = [attr.name for attr in dataclasses.fields(cls)]
+        return cls(**{attr: getattr(args, attr) for attr in attrs})
+
+
+def load_model(server_args, tp_rank):
+    suppress_other_loggers()
+
+    model_config = ModelConfig(path=server_args.model_path)
+    model_runner = ModelRunner(
+        model_config=model_config,
+        mem_fraction_static=server_args.mem_fraction_static,
+        gpu_id=tp_rank,
+        tp_rank=tp_rank,
+        tp_size=server_args.tp_size,
+        nccl_port=28888,
+        server_args=server_args,
+    )
+    tokenizer = get_tokenizer(
+        server_args.tokenizer_path,
+        tokenizer_mode=server_args.tokenizer_mode,
+        trust_remote_code=server_args.trust_remote_code,
+    )
+    if server_args.tp_size > 1:
+        dist.barrier()
+    return model_runner, tokenizer
+
+
+def prepare_inputs(bench_args, tokenizer):
+    prompts = [
+        "The capital of France is",
+        "The capital of the United Kindom is",
+        "Today is a sunny day and I like",
+    ]
+    input_ids = [tokenizer.encode(p) for p in prompts]
+    sampling_params = SamplingParams(
+        temperature=0,
+        max_new_tokens=bench_args.output_len,
+    )
+
+    reqs = []
+    for i in range(len(prompts)):
+        assert len(input_ids[i]) > bench_args.cut_len
+
+        tmp_input_ids = input_ids[i][:bench_args.cut_len]
+        req = Req(rid=i, origin_input_text=prompts[i], origin_input_ids=tmp_input_ids)
+        req.prefix_indices = []
+        req.sampling_params = sampling_params
+        req.input_ids = req.origin_input_ids
+        reqs.append(req)
+
+    return input_ids, reqs
+
+
+def prepare_extend_inputs(bench_args, input_ids, reqs, model_runner):
+    for i in range(len(reqs)):
+        req = reqs[i]
+        req.input_ids += input_ids[i][bench_args.cut_len:]
+        req.prefix_indices = model_runner.req_to_token_pool.req_to_token[
+            i, :bench_args.cut_len
+        ]
+    return reqs
+
+
+def prepare_synthetic_inputs(bench_args, tokenizer):
+    input_ids = np.ones((bench_args.batch_size, bench_args.input_len), dtype=np.int32)
+    sampling_params = SamplingParams(
+        temperature=0,
+        max_new_tokens=bench_args.output_len,
+    )
+
+    reqs = []
+    for i in range(len(input_ids)):
+        req = Req(rid=i, origin_input_text="", origin_input_ids=list(input_ids[i]))
+        req.prefix_indices = []
+        req.sampling_params = sampling_params
+        req.input_ids = req.origin_input_ids
+        reqs.append(req)
+
+    return reqs
+
+
+def extend(reqs, model_runner):
+    batch = Batch.init_new(
+        reqs=reqs,
+        req_to_token_pool=model_runner.req_to_token_pool,
+        token_to_kv_pool=model_runner.token_to_kv_pool,
+        tree_cache=None)
+    batch.prepare_for_extend(model_runner.model_config.vocab_size, None)
+    output = model_runner.forward(batch, ForwardMode.EXTEND)
+    next_token_ids, _ = batch.sample(output.next_token_logits)
+    return next_token_ids, output.next_token_logits, batch
+
+
+def decode(input_token_ids, batch, model_runner):
+    batch.prepare_for_decode(input_token_ids.cpu().numpy())
+    output = model_runner.forward(batch, ForwardMode.DECODE)
+    next_token_ids, _ = batch.sample(output.next_token_logits)
+    return next_token_ids, output.next_token_logits
+
+
+def correctness_test(
+    server_args,
+    bench_args,
+    tp_rank,
+):
+    rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None
+
+    # Load the model
+    model_runner, tokenizer = load_model(server_args, tp_rank)
+
+    # Prepare inputs
+    input_ids, reqs = prepare_inputs(bench_args, tokenizer)
+
+    # Prefill
+    next_token_ids, next_token_logits, batch = extend(reqs, model_runner)
+    rank_print("prefill logits (first half)", next_token_logits)
+
+    # Prepare extend inputs
+    reqs = prepare_extend_inputs(bench_args, input_ids, reqs, model_runner)
+
+    # Extend
+    next_token_ids, next_token_logits, batch = extend(reqs, model_runner)
+    rank_print("prefill logits (final)", next_token_logits)
+
+    # Decode
+    output_ids = [list(req.input_ids) for req in reqs]
+    for _ in range(bench_args.output_len):
+        next_token_ids, _ = decode(next_token_ids, batch, model_runner)
+        for i in range(len(reqs)):
+            output_ids[i].append(next_token_ids[i])
+
+    # Print
+    for i in range(len(reqs)):
+        print(tokenizer.decode(output_ids[i]))
+
+
+def latency_test(
+    server_args,
+    bench_args,
+    tp_rank,
+):
+    rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None
+
+    # Load the model
+    model_runner, tokenizer = load_model(server_args, tp_rank)
+
+    # Prepare inputs
+    reqs = prepare_synthetic_inputs(bench_args, tokenizer)
+
+    def clear():
+        model_runner.req_to_token_pool.clear()
+        model_runner.token_to_kv_pool.clear()
+
+    @torch.inference_mode()
+    def run_once(output_len):
+        # Prefill
+        torch.cuda.synchronize()
+        tic = time.time()
+        next_token_ids, _, batch = extend(reqs, model_runner)
+        torch.cuda.synchronize()
+        latency = time.time() - tic
+        throughput = bench_args.input_len * bench_args.batch_size / latency
+        rank_print(f"Prefill. latency: {latency:6.3f} s, throughput: {throughput:9.2f} token/s")
+
+        # Decode
+        for _ in range(output_len):
+            torch.cuda.synchronize()
+            tic = time.time()
+            next_token_ids, _ = decode(next_token_ids, batch, model_runner)
+            torch.cuda.synchronize()
+            latency = time.time() - tic
+            throughput = bench_args.batch_size / latency
+            rank_print(f"Decode. latency: {latency:6.3f} s, throughput: {throughput:9.2f} token/s")
+
+    # Warm up
+    run_once(4)
+    clear()
+
+    # Run again
+    run_once(bench_args.output_len)
+
+
+def main(server_args, bench_args):
+    print(bench_args)
+
+    if bench_args.correctness_test:
+        work_func = correctness_test
+    else:
+        work_func = latency_test
+
+    workers = []
+    for tp_rank in range(server_args.tp_size):
+        proc = multiprocessing.Process(
+            target=work_func,
+            args=(
+                server_args,
+                bench_args,
+                tp_rank,
+            ),
+        )
+        proc.start()
+        workers.append(proc)
+
+    for proc in workers:
+        proc.join()
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    ServerArgs.add_cli_args(parser)
+    BenchArgs.add_cli_args(parser)
+    args = parser.parse_args()
+
+    server_args = ServerArgs.from_cli_args(args)
+    bench_args = BenchArgs.from_cli_args(args)
+
+    logging.basicConfig(
+        level=getattr(logging, server_args.log_level.upper()),
+        format="%(message)s",
+    )
+
+    main(server_args, bench_args)
\ No newline at end of file
diff --git a/python/sglang/srt/managers/controller/model_runner.py b/python/sglang/srt/managers/controller/model_runner.py
index 942f29070..84ecd98fe 100644
--- a/python/sglang/srt/managers/controller/model_runner.py
+++ b/python/sglang/srt/managers/controller/model_runner.py
@@ -23,6 +23,7 @@ from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
     get_available_gpu_memory,
     is_multimodal_model,
+    monkey_patch_vllm_dummy_weight_loader,
     monkey_patch_vllm_p2p_access_check,
 )
 
@@ -229,6 +230,7 @@ class ModelRunner:
         self.nccl_port = nccl_port
         self.server_args = server_args
         self.is_multimodal_model = is_multimodal_model(self.model_config)
+        monkey_patch_vllm_dummy_weight_loader()
 
         # Init torch distributed
         logger.info(f"[gpu_id={self.gpu_id}] Set cuda device.")
diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py
index f93e0be36..94cabee50 100644
--- a/python/sglang/srt/utils.py
+++ b/python/sglang/srt/utils.py
@@ -466,6 +466,48 @@ def monkey_patch_vllm_p2p_access_check(gpu_id: int):
     setattr(tgt, "gpu_p2p_access_check", lambda *arg, **kwargs: True)
 
 
+def monkey_patch_vllm_dummy_weight_loader():
+    """
+    Monkey patch the dummy weight loader in vllm to call process_weights_after_loading.
+    """
+
+    from vllm.model_executor.model_loader.loader import (
+        ModelConfig, DeviceConfig, LoRAConfig, VisionLanguageConfig,
+        ParallelConfig, SchedulerConfig, CacheConfig, nn,
+        set_default_torch_dtype, _initialize_model, initialize_dummy_weights,
+        DummyModelLoader
+    )
+
+    def load_model(self, *, model_config: ModelConfig,
+                   device_config: DeviceConfig,
+                   lora_config: Optional[LoRAConfig],
+                   vision_language_config: Optional[VisionLanguageConfig],
+                   parallel_config: ParallelConfig,
+                   scheduler_config: SchedulerConfig,
+                   cache_config: CacheConfig) -> nn.Module:
+        with set_default_torch_dtype(model_config.dtype):
+            with torch.device(device_config.device):
+                model = _initialize_model(model_config, self.load_config,
+                                          lora_config, vision_language_config,
+                                          cache_config)
+
+            for _, module in model.named_modules():
+                quant_method = getattr(module, "quant_method", None)
+                if quant_method is not None:
+                    quant_method.process_weights_after_loading(module)
+                # FIXME: Remove this after Mixtral is updated
+                # to use quant_method.
+                if hasattr(module, "process_weights_after_loading"):
+                    module.process_weights_after_loading()
+
+            # NOTE(woosuk): For accurate performance evaluation, we assign
+            # random values to the weights.
+            initialize_dummy_weights(model)
+        return model.eval()
+
+    setattr(DummyModelLoader, "load_model", load_model)
+
+
 API_KEY_HEADER_NAME = "X-API-Key"
 
 
diff --git a/test/srt/model/bench_llama_low_api.py b/test/srt/model/bench_llama_low_api.py
deleted file mode 100644
index 339574228..000000000
--- a/test/srt/model/bench_llama_low_api.py
+++ /dev/null
@@ -1,275 +0,0 @@
-import multiprocessing as mp
-import time
-from dataclasses import dataclass
-
-import torch
-import torch.distributed as dist
-
-from sglang.srt.managers.controller.model_runner import ModelRunner
-from sglang.srt.model_config import ModelConfig
-
-
-@dataclass
-class BenchBatch:
-    req_to_token_pool: torch.Tensor
-    token_to_kv_pool: torch.Tensor
-
-    input_ids: torch.Tensor = None
-    position_ids_offsets: torch.Tensor = None
-    seq_lens: torch.Tensor = None
-    prefix_lens: torch.Tensor = None
-    req_pool_indices: torch.Tensor = None
-    out_cache_loc: torch.Tensor = None
-    out_cache_cont_start: torch.Tensor = None
-    out_cache_cont_end: torch.Tensor = None
-
-    def __init__(self, model_runner: ModelRunner):
-        self.req_to_token_pool = model_runner.req_to_token_pool
-        self.token_to_kv_pool = model_runner.token_to_kv_pool
-
-    def init_prefill_batch(self, input_ids, batch_size, seq_len):
-        self.input_ids = input_ids
-        self.position_ids_offsets = torch.zeros(
-            batch_size, dtype=torch.int32, device="cuda"
-        )
-        self.seq_lens = torch.full(
-            (batch_size,), seq_len, dtype=torch.int32, device="cuda"
-        )
-        self.prefix_lens = torch.zeros(batch_size, dtype=torch.int32, device="cuda")
-        self.req_pool_indices = self.req_to_token_pool.alloc(batch_size)
-        self.out_cache_loc = self.token_to_kv_pool.alloc(batch_size * seq_len)
-
-        for i in range(batch_size):
-            n_idx = self.req_pool_indices[i].item()
-            self.req_to_token_pool.req_to_token[n_idx, :seq_len] = self.out_cache_loc[
-                i * seq_len : (i + 1) * seq_len
-            ]
-
-    def update_extend(
-        self, input_ids, batch_size, prefix_len, extend_len, prefix_req_idx
-    ):
-        self.input_ids = input_ids
-        self.position_ids_offsets = torch.zeros(
-            batch_size, dtype=torch.int32, device="cuda"
-        )
-        self.seq_lens = torch.full(
-            (batch_size,), prefix_len + extend_len, dtype=torch.int32, device="cuda"
-        )
-        self.prefix_lens = torch.full(
-            (batch_size,), prefix_len, dtype=torch.int32, device="cuda"
-        )
-        self.req_pool_indices = self.req_to_token_pool.alloc(batch_size)
-        self.out_cache_loc = self.token_to_kv_pool.alloc(batch_size * extend_len)
-
-        req_to_token = self.req_to_token_pool.req_to_token
-        fork_num = batch_size // prefix_req_idx.shape[0]
-        for i in range(batch_size):
-            p_idx = prefix_req_idx[i // fork_num].item()
-            n_idx = self.req_pool_indices[i].item()
-            req_to_token[n_idx, :prefix_len] = req_to_token[p_idx, :prefix_len]
-            req_to_token[n_idx, prefix_len : prefix_len + extend_len] = (
-                self.out_cache_loc[i * extend_len : (i + 1) * extend_len]
-            )
-
-    def update_decode(self, predict_ids, batch_size):
-        assert predict_ids.shape[0] == batch_size
-        assert batch_size == self.req_pool_indices.shape[0]
-
-        self.input_ids = predict_ids.reshape(-1)
-        self.prefix_lens = None
-        (
-            self.out_cache_loc,
-            self.out_cache_cont_start,
-            self.out_cache_cont_end,
-        ) = self.token_to_kv_pool.alloc_contiguous(batch_size)
-        self.req_to_token_pool.req_to_token[self.req_pool_indices, self.seq_lens] = (
-            self.out_cache_loc
-        )
-        self.seq_lens.add_(1)
-
-
-def prefill(model_runner: ModelRunner, batch: BenchBatch):
-    logits, _ = model_runner.forward_extend(
-        batch.input_ids,
-        batch.req_pool_indices,
-        batch.seq_lens,
-        batch.prefix_lens,
-        batch.position_ids_offsets,
-        batch.out_cache_loc,
-        False,
-    )
-
-    prob_out = torch.softmax(logits, dim=-1)
-    predict_ids = torch.argmax(prob_out, dim=1, keepdim=True)
-    predict_ids = predict_ids.detach().cpu().numpy()
-
-    return predict_ids
-
-
-def extend(model_runner: ModelRunner, batch: BenchBatch):
-    logits, _ = model_runner.forward_extend(
-        batch.input_ids,
-        batch.req_pool_indices,
-        batch.seq_lens,
-        batch.prefix_lens,
-        batch.position_ids_offsets,
-        batch.out_cache_loc,
-        True,
-    )
-
-    prob_out = torch.softmax(logits, dim=-1)
-    predict_ids = torch.argmax(prob_out, dim=1, keepdim=True)
-    predict_ids = predict_ids.detach().cpu().numpy()
-
-    return predict_ids
-
-
-def decode(model_runner: ModelRunner, batch: BenchBatch):
-    logits = model_runner.forward_decode(
-        batch.input_ids,
-        batch.req_pool_indices,
-        batch.seq_lens,
-        None,
-        batch.position_ids_offsets,
-        None,
-        batch.out_cache_cont_start,
-        batch.out_cache_cont_end,
-    )
-
-    prob_out = torch.softmax(logits, dim=-1)
-    predict_ids = torch.argmax(prob_out, dim=1, keepdim=True)
-    predict_ids = predict_ids.detach().cpu().numpy()
-
-    return predict_ids
-
-
-def bench_generate_worker(
-    model_path,
-    tp_rank,
-    tp_size,
-    shared_num,
-    unique_num,
-    shared_len,
-    unique_len,
-    decode_len,
-    server_args_dict,
-):
-    assert unique_num % shared_num == 0
-
-    model_config = ModelConfig(path=model_path)
-    model_runner = ModelRunner(
-        model_config=model_config,
-        mem_fraction_static=0.8,
-        tp_rank=tp_rank,
-        tp_size=tp_size,
-        nccl_port=28888,
-        server_args_dict=server_args_dict,
-    )
-
-    batch = BenchBatch(model_runner)
-
-    # warm up
-    for _ in range(1):
-        input_ids = torch.randint(
-            low=5, high=100, size=(shared_num * shared_len,)
-        ).cuda()
-        batch.init_prefill_batch(input_ids, shared_num, shared_len)
-        _ = prefill(model_runner, batch)
-
-        input_ids = torch.randint(
-            low=5, high=100, size=(unique_num * unique_len,)
-        ).cuda()
-        batch.update_extend(
-            input_ids, unique_num, shared_len, unique_len, batch.req_pool_indices
-        )
-        predict_ids = extend(model_runner, batch)
-
-        for i in range(decode_len):
-            predict_ids = torch.from_numpy(predict_ids).cuda()
-            batch.update_decode(predict_ids, unique_num)
-            predict_ids = decode(model_runner, batch)
-
-        model_runner.req_to_token_pool.clear()
-        model_runner.token_to_kv_pool.clear()
-
-    if tp_size > 1:
-        dist.barrier()
-
-    prefill_start = time.time()
-    input_ids = torch.randint(low=5, high=100, size=(shared_num * shared_len,)).cuda()
-    batch.init_prefill_batch(input_ids, shared_num, shared_len)
-    _ = prefill(model_runner, batch)
-    if tp_rank == 0:
-        print(f"prefill: {(time.time() - prefill_start) * 1000:.2f} ms")
-
-    extend_start = time.time()
-    input_ids = torch.randint(low=5, high=100, size=(unique_num * unique_len,)).cuda()
-    batch.update_extend(
-        input_ids, unique_num, shared_len, unique_len, batch.req_pool_indices
-    )
-    predict_ids = extend(model_runner, batch)
-    if tp_rank == 0:
-        print(f"extend: {(time.time() - extend_start) * 1000:.2f} ms")
-
-    for i in range(decode_len):
-        decode_start = time.time()
-        predict_ids = torch.from_numpy(predict_ids).cuda()
-        batch.update_decode(predict_ids, unique_num)
-        predict_ids = decode(model_runner, batch)
-        if tp_rank == 0:
-            print(f"decode {i}: {(time.time() - decode_start) * 1000:.2f} ms")
-
-
-def bench_generate(
-    model_path,
-    tp_size,
-    shared_num,
-    unique_num,
-    shared_len,
-    unique_len,
-    decode_len,
-    server_args_dict,
-):
-    print(
-        f"tp_size: {tp_size}, "
-        f"shared_num: {shared_num}, "
-        f"unique_num: {unique_num}, "
- f"shared_len: {shared_len}, " - f"unique_len: {unique_len}, " - f"decode_len: {decode_len}, " - f"server_args: {server_args_dict}" - ) - workers = [] - for tp_rank in range(tp_size): - proc = mp.Process( - target=bench_generate_worker, - args=( - model_path, - tp_rank, - tp_size, - shared_num, - unique_num, - shared_len, - unique_len, - decode_len, - server_args_dict, - ), - ) - proc.start() - workers.append(proc) - - for proc in workers: - proc.join() - - -if __name__ == "__main__": - bench_generate( - model_path="meta-llama/Llama-2-7b-chat-hf", - tp_size=1, - shared_num=1, - unique_num=32, - shared_len=256, - unique_len=256, - decode_len=8, - server_args_dict={}, - ) diff --git a/test/srt/model/test_llama_extend.py b/test/srt/model/test_llama_extend.py deleted file mode 100644 index 2814dc2a0..000000000 --- a/test/srt/model/test_llama_extend.py +++ /dev/null @@ -1,109 +0,0 @@ -import multiprocessing -import os -import time - -import numpy as np -import torch -import torch.distributed as dist -import transformers - -from sglang.srt.managers.controller.infer_batch import Batch, ForwardMode, Req -from sglang.srt.managers.controller.model_runner import ModelRunner -from sglang.srt.model_config import ModelConfig -from sglang.srt.sampling_params import SamplingParams - - -def test_generate_worker(model_path, tp_rank, tp_size): - model_config = ModelConfig(path=model_path) - model = ModelRunner(model_config, 0.8, tp_rank, tp_size, 28888) - tokenizer = transformers.AutoTokenizer.from_pretrained(model_path) - - # Input - prompts = [ - "The capital of France is", - "Today is a sunny day and I like", - ] - sampling_params = SamplingParams(temperature=0) - - cut_num = 4 - - reqs = [] - for i in range(len(prompts)): - input_ids = tokenizer.encode(prompts[i])[:cut_num] - req = Req(i, prompts[i], input_ids) - req.sampling_params = sampling_params - reqs.append(req) - - # Prefill - batch = Batch.init_new(reqs, model.req_to_token_pool, model.token_to_kv_pool, None) - batch.prepare_for_extend(model.model_config.vocab_size, None) - logits, _ = model.forward(batch, ForwardMode.EXTEND) - next_token_ids, next_token_probs = batch.sample(logits) - print("extend logits (first)", logits) - - # Extend - for i in range(len(prompts)): - req = reqs[i] - req.input_ids += tokenizer.encode(prompts[i])[cut_num:] - req.prefix_indices = model.req_to_token_pool.req_to_token[ - batch.req_pool_indices[i], :cut_num - ] - batch = Batch.init_new(reqs, model.req_to_token_pool, model.token_to_kv_pool, None) - batch.prepare_for_extend(model.model_config.vocab_size, None) - logits, _ = model.forward(batch, ForwardMode.EXTEND) - next_token_ids, next_token_probs = batch.sample(logits) - - print("extend logits", logits) - print( - "next_token_ids", next_token_ids, [tokenizer.decode(x) for x in next_token_ids] - ) - - # Decode - for i in range(6): - batch.prepare_for_decode(next_token_ids.cpu().numpy()) - logits, _ = model.forward(batch, ForwardMode.DECODE) - next_token_ids, next_token_probs = batch.sample(logits) - - print( - "next_token_ids", - next_token_ids, - [tokenizer.decode(x) for x in next_token_ids], - ) - - -def test_generate(model_path, tp_size): - workers = [] - for tp_rank in range(tp_size): - proc = multiprocessing.Process( - target=test_generate_worker, - args=( - model_path, - tp_rank, - tp_size, - ), - ) - proc.start() - workers.append(proc) - - for proc in workers: - proc.join() - - -if __name__ == "__main__": - os.environ["TOKENIZERS_PARALLELISM"] = "false" - test_generate("TinyLlama/TinyLlama-1.1B-Chat-v0.4", 1) 
-
-    # Reference output for TinyLlama-1.1B-Chat-v0.4
-    # extend logits (first) tensor([[-10.0312, -9.5000, 0.8896, ..., -4.9375, -3.2402, -3.3633],
-    #         [ -9.1797, -10.2500, 2.7168, ..., -4.3359, -4.0664, -4.1289]],
-    #        device='cuda:0', dtype=torch.float16)
-    # extend logits tensor([[-8.3125, -7.1172, 3.3359, ..., -4.9531, -4.1289, -3.4121],
-    #         [-9.6406, -9.0547, 4.0195, ..., -5.3086, -4.7188, -4.4609]],
-    #        device='cuda:0', dtype=torch.float16)
-    # next_token_ids tensor([3681, 304], device='cuda:0') ['Paris', 'to']
-    # next_token_ids tensor([29889, 748], device='cuda:0') ['.', 'go']
-    # next_token_ids tensor([ 13, 363], device='cuda:0') ['\n', 'for']
-    # next_token_ids tensor([1576, 263], device='cuda:0') ['The', 'a']
-    # next_token_ids tensor([7483, 6686], device='cuda:0') ['capital', 'walk']
-    # next_token_ids tensor([310, 297], device='cuda:0') ['of', 'in']
-    # next_token_ids tensor([278, 278], device='cuda:0') ['the', 'the']
diff --git a/test/srt/model/test_llama_low_api.py b/test/srt/model/test_llama_low_api.py
deleted file mode 100644
index 0eb1574b1..000000000
--- a/test/srt/model/test_llama_low_api.py
+++ /dev/null
@@ -1,211 +0,0 @@
-import multiprocessing
-import time
-
-import numpy as np
-import torch
-import torch.distributed as dist
-
-from sglang.srt.managers.controller.model_runner import ModelRunner
-from sglang.srt.model_config import ModelConfig
-
-
-def test_generate_worker(
-    model_path, tp_rank, tp_size, batch_size, input_len, output_len
-):
-    model_config = ModelConfig(path=model_path)
-    model = ModelRunner(model_config, 0.8, tp_rank, tp_size, 28888)
-
-    # Prepare data
-    input_ids = np.vstack([np.arange(5, input_len + 5) for _ in range(batch_size)])
-    input_ids = input_ids.reshape(-1)
-    input_ids = torch.tensor(input_ids).cuda()
-
-    def init_batch_data(model, batch_size, input_len):
-        req_pool_indices = model.req_to_token_pool.alloc(batch_size)
-        seq_lens = torch.full(
-            (batch_size,), input_len, dtype=torch.int32, device="cuda"
-        )
-        prefix_lens = torch.zeros(batch_size, dtype=torch.int32, device="cuda")
-        position_ids_offsets = torch.zeros(batch_size, dtype=torch.int32, device="cuda")
-
-        out_cache_loc = model.token_to_kv_pool.alloc(batch_size * input_len)
-        for i in range(batch_size):
-            req_idx = req_pool_indices[i].item()
-            model.req_to_token_pool.req_to_token[req_idx, :input_len] = out_cache_loc[
-                i * input_len : (i + 1) * input_len
-            ]
-
-        return (
-            req_pool_indices,
-            seq_lens,
-            prefix_lens,
-            position_ids_offsets,
-            out_cache_loc,
-        )
-
-    def prefill(print_logits):
-        nonlocal predict_ids
-
-        logits, _ = model.forward_prefill(
-            input_ids,
-            req_pool_indices,
-            seq_lens,
-            prefix_lens,
-            position_ids_offsets,
-            out_cache_loc,
-            False,
-        )
-        prob_out = torch.softmax(logits, dim=-1)
-        predict_ids = torch.argmax(prob_out, dim=1, keepdim=True)
-        predict_ids = predict_ids.detach().cpu().numpy()
-
-        if print_logits and tp_rank == 0:
-            print("prefill logits", logits, logits.shape)
-
-    def decode(print_logits):
-        nonlocal predict_ids
-
-        (
-            out_cache_loc,
-            out_cache_cont_start,
-            out_cache_cont_end,
-        ) = model.token_to_kv_pool.alloc_contiguous(batch_size)
-        model.req_to_token_pool.req_to_token[req_pool_indices, seq_lens] = out_cache_loc
-        seq_lens.add_(1)
-        logits, _ = model.forward_decode(
-            torch.from_numpy(predict_ids).cuda().reshape(-1),
-            req_pool_indices,
-            seq_lens,
-            None,
-            position_ids_offsets,
-            None,
-            out_cache_cont_start,
-            out_cache_cont_end,
-            False,
-        )
-        prob_out = torch.softmax(logits, dim=-1)
-        predict_ids = torch.argmax(prob_out, dim=1, keepdim=True)
-        predict_ids = predict_ids.detach().cpu().numpy()
-        if print_logits and tp_rank == 0:
-            print("decode", i, logits)
-
-    # Warm up
-    (
-        req_pool_indices,
-        seq_lens,
-        prefix_lens,
-        position_ids_offsets,
-        out_cache_loc,
-    ) = init_batch_data(model, batch_size, input_len)
-    predict_ids = None
-
-    prefill(True)
-    for i in range(output_len):
-        decode(True)
-
-    for i in range(batch_size):
-        req_idx = req_pool_indices[i].item()
-        model.token_to_kv_pool.dec_refs(
-            model.req_to_token_pool.req_to_token[req_idx, : seq_lens[i]]
-        )
-    model.req_to_token_pool.free(req_pool_indices)
-
-    # Benchmark
-    if tp_size > 1:
-        dist.barrier()
-    start_time = prefill_start_time = time.time()
-
-    (
-        req_pool_indices,
-        seq_lens,
-        prefix_lens,
-        position_ids_offsets,
-        out_cache_loc,
-    ) = init_batch_data(model, batch_size, input_len)
-
-    prefill(False)
-
-    if tp_rank == 0:
-        print(f"prefill cost: {(time.time() - prefill_start_time) * 1000:.2f} ms")
-
-    for i in range(output_len):
-        step_start = time.time()
-
-        decode(False)
-
-        step_end = time.time()
-
-        if i % 100 == 0 or i == output_len - 1:
-            if tp_rank == 0:
-                print(f"step {i} cost: {(step_end - step_start) * 1000:.2f} ms")
-
-    end_time = time.time()
-
-    if tp_rank == 0:
-        print(f"total cost: {(end_time - start_time) * 1000:.2f}")
-
-
-def test_generate(model_path, tp_size, batch_size, input_len, output_len):
-    workers = []
-    for tp_rank in range(tp_size):
-        proc = multiprocessing.Process(
-            target=test_generate_worker,
-            args=(
-                model_path,
-                tp_rank,
-                tp_size,
-                batch_size,
-                input_len,
-                output_len,
-            ),
-        )
-        proc.start()
-        workers.append(proc)
-
-    for proc in workers:
-        proc.join()
-
-
-if __name__ == "__main__":
-    test_generate("TinyLlama/TinyLlama-1.1B-Chat-v0.4", 1, 1, 256, 8)
-    # test_generate("meta-llama/Llama-2-7b-chat-hf", 1, 16, 256, 8)
-
-    # Reference output for TinyLlama-1.1B-Chat-v0.4 (1, 32, 8)
-    # prefill logits tensor([[-1.3380e-03, 4.4702e-01, 2.9082e+00, ..., -1.8398e+00,
-    #          1.8281e+00, 2.1816e+00]], device='cuda:0')
-    # decode 0 tensor([[-0.3904, 0.8784, 3.6934, ..., -2.4473, 1.5811, 2.0098]],
-    #        device='cuda:0')
-    # decode 1 tensor([[-0.3552, 0.0635, 2.5781, ..., -2.5820, 1.3047, 1.7607]],
-    #        device='cuda:0')
-    # decode 2 tensor([[-1.5645, -1.1963, 3.8145, ..., -2.9766, 1.0244, 1.0645]],
-    #        device='cuda:0')
-    # decode 3 tensor([[-1.3682, -0.6548, 4.2734, ..., -2.8711, 1.1172, 1.1494]],
-    #        device='cuda:0')
-    # decode 4 tensor([[-1.0205, -0.0060, 4.4844, ..., -2.7090, 1.6143, 1.8135]],
-    #        device='cuda:0')
-    # decode 5 tensor([[ 0.4260, 1.6006, 4.3633, ..., -2.2480, 2.5547, 2.8379]],
-    #        device='cuda:0')
-    # decode 6 tensor([[ 0.7095, 2.1816, 5.0078, ..., -2.1309, 3.0293, 3.0840]],
-    #        device='cuda:0')
-    # decode 7 tensor([[-0.2883, 1.1289, 4.7188, ..., -2.4023, 2.1055, 2.1836]],
-    #        device='cuda:0')
-
-    # Reference output for TinyLlama-1.1B-Chat-v0.4 (1, 256, 8)
-    # prefill logits tensor([[-2.5840, -2.7227, 6.8047, ..., -2.3613, 0.1224, 0.5952]],
-    #        device='cuda:0')
-    # decode 0 tensor([[-0.6235, -0.7690, 9.2891, ..., -1.4922, 2.8008, 2.9531]],
-    #        device='cuda:0')
-    # decode 1 tensor([[-1.3662, -1.4648, 7.1250, ..., -1.7861, 1.7363, 1.8857]],
-    #        device='cuda:0')
-    # decode 2 tensor([[-0.8540, -0.5947, 9.1328, ..., -2.1211, 2.9707, 2.8945]],
-    #        device='cuda:0')
-    # decode 3 tensor([[ 0.0652, 1.0312, 8.1250, ..., -2.0586, 3.4727, 3.6172]],
-    #        device='cuda:0')
-    # decode 4 tensor([[-0.0459, 1.0098, 9.1406, ..., -2.1797, 3.8320, 3.9355]],
-    #        device='cuda:0')
-    # decode 5 tensor([[ 0.2964, 1.3564, 9.8828, ..., -2.1602, 4.1836, 4.2422]],
-    #        device='cuda:0')
-    # decode 6 tensor([[ 0.6475, 1.8105, 10.1250, ..., -2.0098, 4.2578, 4.4062]],
-    #        device='cuda:0')
-    # decode 7 tensor([[ 0.4985, 1.4746, 9.9062, ..., -1.9141, 3.9863, 4.3047]],
-    #        device='cuda:0')
diff --git a/test/srt/model/test_llava_low_api.py b/test/srt/model/test_llava_low_api.py
deleted file mode 100644
index 2a9fa543d..000000000
--- a/test/srt/model/test_llava_low_api.py
+++ /dev/null
@@ -1,164 +0,0 @@
-import multiprocessing
-import time
-
-import numpy as np
-import torch
-import torch.distributed as dist
-
-from sglang.srt.hf_transformers_utils import get_processor
-from sglang.srt.managers.controller.infer_batch import ForwardMode
-from sglang.srt.managers.controller.model_runner import InputMetadata, ModelRunner
-from sglang.srt.model_config import ModelConfig
-from sglang.srt.utils import load_image
-
-
-def init_batch_data(model, batch_size, input_len):
-    req_pool_indices = model.req_to_token_pool.alloc(batch_size)
-    seq_lens = torch.full((batch_size,), input_len, dtype=torch.int32, device="cuda")
-    prefix_lens = torch.zeros(batch_size, dtype=torch.int32, device="cuda")
-    position_ids_offsets = torch.zeros(batch_size, dtype=torch.int32, device="cuda")
-
-    out_cache_loc = model.token_to_kv_pool.alloc(batch_size * input_len)
-    for i in range(batch_size):
-        model.req_to_token_pool.req_to_token[i, :input_len] = out_cache_loc[
-            i * input_len : (i + 1) * input_len
-        ]
-
-    return (
-        req_pool_indices,
-        seq_lens,
-        prefix_lens,
-        position_ids_offsets,
-        out_cache_loc,
-    )
-
-
-def prefill(model, tp_rank, params, print_logits):
-    logits, _ = model.forward_extend_multi_modal(
-        *params,
-        False,
-    )
-    prob_out = torch.softmax(logits, dim=-1)
-    predict_ids = torch.argmax(prob_out, dim=1, keepdim=True)
-    predict_ids = predict_ids.detach().cpu().numpy()
-
-    if print_logits and tp_rank == 0:
-        print("prefill logits", logits, logits.shape)
-
-    return predict_ids
-
-
-def decode(step, model, tp_rank, batch_size, predict_ids, params, print_logits):
-    (
-        req_pool_indices,
-        seq_lens,
-        prefix_lens,
-        position_ids_offsets,
-        out_cache_loc,
-    ) = params
-
-    (
-        out_cache_loc,
-        out_cache_cont_start,
-        out_cache_cont_end,
-    ) = model.token_to_kv_pool.alloc_contiguous(batch_size)
-    model.req_to_token_pool.req_to_token[req_pool_indices, seq_lens] = out_cache_loc
-    seq_lens.add_(1)
-    logits, _ = model.forward_decode(
-        torch.from_numpy(predict_ids).cuda().reshape(-1),
-        req_pool_indices,
-        seq_lens,
-        None,
-        position_ids_offsets,
-        None,
-        out_cache_cont_start,
-        out_cache_cont_end,
-        False,
-    )
-    prob_out = torch.softmax(logits, dim=-1)
-    predict_ids = torch.argmax(prob_out, dim=1, keepdim=True)
-    predict_ids = predict_ids.detach().cpu().numpy()
-    if print_logits and tp_rank == 0:
-        print("decode", step, logits)
-    return predict_ids
-
-
-def test_generate_worker(
-    model_path,
-    tp_rank,
-    tp_size,
-):
-    model_config = ModelConfig(path=model_path)
-    model = ModelRunner(model_config, 0.8, tp_rank, tp_size, 28888)
-    # print(model.model)
-
-    # Prepare data
-    prompt = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. USER: <image>\nDescribe this picture ASSISTANT:"
-    image_path = "/home/ubuntu/sglang/test/lang/test_image.png"
-    image = load_image(image_path)
-
-    processor = get_processor("llava-hf/llava-1.5-7b-hf")
-    input_ids = processor.tokenizer.encode(prompt)
-    pixel_values = processor.image_processor(image)["pixel_values"]
-    input_ids, offset = model.model.pad_input_ids(
-        input_ids,
-        [
-            0,
-        ],
-    )
-
-    params = init_batch_data(model, 1, len(input_ids))
-
-    # inference
-    output_ids = []
-    prefill_params = (
-        torch.tensor(np.array(input_ids)).cuda(),
-        np.array(pixel_values),
-        [None],
-        [offset],
-        *params,
-    )
-    predict_ids = prefill(model, tp_rank=0, params=prefill_params, print_logits=False)
-    output_ids.append(predict_ids[0][0])
-    for i in range(16):
-        predict_ids = decode(
-            i,
-            model,
-            tp_rank=0,
-            batch_size=1,
-            predict_ids=predict_ids,
-            params=params,
-            print_logits=False,
-        )
-        output_ids.append(predict_ids[0][0])
-
-    # detokenization
-    output = processor.tokenizer.batch_decode(
-        [output_ids], skip_special_tokens=True, clean_up_tokenization_spaces=False
-    )[0]
-    assert (
-        output
-        == "The image features a man standing on the back of a yellow taxi cab, holding"
-    )
-
-
-def test_generate(model_path, tp_size):
-    workers = []
-    for tp_rank in range(tp_size):
-        proc = multiprocessing.Process(
-            target=test_generate_worker,
-            args=(
-                model_path,
-                tp_rank,
-                tp_size,
-            ),
-        )
-        proc.start()
-        workers.append(proc)
-
-    for proc in workers:
-        proc.join()
-
-
-if __name__ == "__main__":
-    test_generate("liuhaotian/llava-v1.5-7b", 1)
diff --git a/test/srt/test_httpserver_classify.py b/test/srt/test_httpserver_classify.py
index e3b74dc17..40da2b749 100644
--- a/test/srt/test_httpserver_classify.py
+++ b/test/srt/test_httpserver_classify.py
@@ -1,5 +1,7 @@
 """
 Usage:
+python3 -m sglang.launch_server --model-path /model/llama-classification
+
 python3 test_httpserver_classify.py
 """
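
Reviewer note: a minimal sketch for checking this change locally (it assumes a CUDA GPU, a checkout run from the repo root, and that the TinyLlama checkpoint can be downloaded) is to run the HF reference script and the new benchmark's correctness mode side by side, then compare their output against the reference values quoted in the two docstrings above:

    # Hugging Face reference implementation (moved to playground/ by this PR)
    python3 playground/reference_hf.py --model TinyLlama/TinyLlama-1.1B-Chat-v0.4

    # sglang prefill/extend/decode path
    python -m sglang.bench_latency --model-path TinyLlama/TinyLlama-1.1B-Chat-v0.4 --correctness-test

Note that bench_latency's first extend pass runs on prompts truncated to cut_len tokens, so only its "prefill logits (final)" rows and the decoded text, not the "(first half)" rows, are expected to line up with the reference_hf prefill logits, up to small float16 differences.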