diff --git a/README.md b/README.md
index 2ac666c6b..90822b176 100644
--- a/README.md
+++ b/README.md
@@ -377,6 +377,14 @@ python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port
 python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000 --mem-fraction-static 0.7
 ```
 - See [hyperparameter_tuning.md](docs/hyperparameter_tuning.md) on tuning hyperparameters for better performance.
+- Add `--nnodes 2` to run tensor parallelism on multiple nodes. If you have two nodes with two GPUs on each node and want to run TP=4, let `sgl-dev-1` be the hostname of the first node and `50000` be an available port.
+```
+# Node 0
+python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --tp 4 --nccl-init sgl-dev-1:50000 --nnodes 2 --node-rank 0
+
+# Node 1
+python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --tp 4 --nccl-init sgl-dev-1:50000 --nnodes 2 --node-rank 1
+```

 ### Supported Models
 - Llama
diff --git a/benchmark/latency_throughput/bench_one.py b/benchmark/latency_throughput/bench_one.py
index cfd96b54c..0bb26ee15 100644
--- a/benchmark/latency_throughput/bench_one.py
+++ b/benchmark/latency_throughput/bench_one.py
@@ -96,8 +96,11 @@ def run_one_batch_size(bs):
     ret = response.json()
     print(ret)

+    input_len = args.input_len if args.input_len else 1
+    output_len = max_new_tokens
+
     output_throughput = bs * max_new_tokens / latency
-    overall_throughput = bs * (args.input_len + max_new_tokens) / latency
+    overall_throughput = bs * (input_len + output_len) / latency
     print(f"latency: {latency:.2f} s")
     print(f"decode throughput: {output_throughput:.2f} token/s")
     print(f"overall throughput: {overall_throughput:.2f} token/s")
diff --git a/benchmark/latency_throughput/bench_serving.py b/benchmark/latency_throughput/bench_serving.py
index 23e8245f2..24816d4bd 100644
--- a/benchmark/latency_throughput/bench_serving.py
+++ b/benchmark/latency_throughput/bench_serving.py
@@ -312,6 +312,9 @@ def main(args: argparse.Namespace):
         np.sum([output_len for _, output_len, _ in REQUEST_LATENCY]) / benchmark_time
     )

+    #latencies = [round(latency, 2) for _, _, latency in REQUEST_LATENCY]
+    #print(latencies)
+
     print(f"Total time: {benchmark_time:.2f} s")
     print(f"Request throughput: {args.num_prompts / benchmark_time:.2f} requests/s")
     print(f"Decoding throughput: {decoding_throughput:.2f} token/s")
diff --git a/python/sglang/README.md b/python/sglang/README.md
index c8c093706..2f298c2c3 100644
--- a/python/sglang/README.md
+++ b/python/sglang/README.md
@@ -2,11 +2,10 @@

 - `backend`: Various backends for the language interpreter.
 - `lang`: The frontend language.
-- `srt`: The runtime for running local models.
+- `srt`: The serving engine for running local models. (SRT = SGLang Runtime).
 - `test`: Test utilities.
 - `api.py`: Public API.
 - `bench_latency.py`: Benchmark utilities.
 - `global_config.py`: The global configs and constants.
 - `launch_server.py`: The entry point of launching local server.
 - `utils.py`: Common utilities.
-
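As an aside on the two-node launch documented in the README hunk above (a sketch, not part of the patch): only node 0, the `--node-rank 0` process, serves HTTP traffic, so a quick way to verify the deployment is to send a small request to it and confirm a completion comes back. The hostname `sgl-dev-1` and the default port `30000` are assumptions carried over from the example commands, and the `/generate` endpoint is the runtime's native generation API documented elsewhere in the README.

```
# Sanity check (assumption: both node-rank 0 and node-rank 1 commands above are
# already running, and node 0 is reachable as sgl-dev-1 on the default port 30000).
import requests

resp = requests.post(
    "http://sgl-dev-1:30000/generate",
    json={
        "text": "The capital of France is",
        "sampling_params": {"temperature": 0, "max_new_tokens": 16},
    },
)
print(resp.json())
```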
diff --git a/python/sglang/srt/managers/controller/manager_multi.py b/python/sglang/srt/managers/controller/manager_multi.py
index 72e3bed80..ea942093a 100644
--- a/python/sglang/srt/managers/controller/manager_multi.py
+++ b/python/sglang/srt/managers/controller/manager_multi.py
@@ -42,6 +42,8 @@ class LoadBalanceMethod(Enum):


 class Controller:
+    """A controller that manages multiple data parallel workers."""
+
     def __init__(
         self,
         load_balance_method: str,
@@ -183,9 +185,11 @@ def start_controller_process(
     except Exception:
         pipe_writer.send(get_exception_traceback())
         raise
-
     pipe_writer.send("init ok")

-    loop = asyncio.get_event_loop()
+
+    loop = asyncio.new_event_loop()
+    loop.set_default_executor(ThreadPoolExecutor(max_workers=256))
+    asyncio.set_event_loop(loop)
     loop.create_task(controller.loop_for_recv_requests())
     loop.run_until_complete(controller.loop_for_forward())
diff --git a/python/sglang/srt/managers/controller/manager_single.py b/python/sglang/srt/managers/controller/manager_single.py
index 4c2720733..c2cb922fc 100644
--- a/python/sglang/srt/managers/controller/manager_single.py
+++ b/python/sglang/srt/managers/controller/manager_single.py
@@ -1,28 +1,104 @@
 """A controller that manages a group of tensor parallel workers."""

-import asyncio
+import multiprocessing
 import logging
-from concurrent.futures import ThreadPoolExecutor
+import os
+import pickle

-import uvloop
+import torch
+import torch.distributed as dist
 import zmq
 import zmq.asyncio

-from sglang.global_config import global_config
-from sglang.srt.managers.controller.tp_worker import ModelTpClient
-from sglang.srt.server_args import PortArgs, ServerArgs
+from sglang.srt.managers.controller.tp_worker import ModelTpServer
+from sglang.srt.server_args import PortArgs, ServerArgs, ModelPortArgs
 from sglang.srt.utils import kill_parent_process
 from sglang.utils import get_exception_traceback

-asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())

 logger = logging.getLogger("srt.controller")


+def run_tp_server(
+    gpu_id: int,
+    tp_rank: int,
+    server_args: ServerArgs,
+    model_port_args: ModelPortArgs,
+    model_overide_args: dict,
+):
+    """Run a tp server."""
+    try:
+        model_server = ModelTpServer(
+            gpu_id,
+            tp_rank,
+            server_args,
+            model_port_args,
+            model_overide_args,
+        )
+        tp_cpu_group = model_server.model_runner.tp_group.cpu_group
+
+        while True:
+            recv_reqs = broadcast_recv_input(None, tp_rank, tp_cpu_group)
+            model_server.exposed_step(recv_reqs)
+    except Exception:
+        logger.error("Exception in run_tp_server:\n" + get_exception_traceback())
+        raise
+
+
+def launch_tp_servers(gpu_ids, tp_rank_range, server_args,
+                      model_port_args, model_overide_args):
+    """Launch multiple tp servers."""
+    procs = []
+    for i in tp_rank_range:
+        proc = multiprocessing.Process(target=run_tp_server, args=(
+            gpu_ids[i], i, server_args, model_port_args, model_overide_args
+        ))
+        proc.start()
+        procs.append(proc)
+
+    return procs
+
+
+def broadcast_recv_input(data, rank, dist_group):
+    """Broadcast inputs from rank=0 to all other ranks with torch.dist backend."""
+
+    if rank == 0:
+        if len(data) == 0:
+            tensor_size = torch.tensor([0], dtype=torch.long)
+            dist.broadcast(tensor_size, src=0, group=dist_group)
+        else:
+            serialized_data = pickle.dumps(data)
+            size = len(serialized_data)
+            tensor_data = torch.ByteTensor(list(serialized_data))
+            tensor_size = torch.tensor([size], dtype=torch.long)
+
+            dist.broadcast(tensor_size, src=0, group=dist_group)
+            dist.broadcast(tensor_data, src=0, group=dist_group)
+    else:
+        tensor_size = torch.tensor([0], dtype=torch.long)
+        dist.broadcast(tensor_size, src=0, group=dist_group)
+        size = tensor_size.item()
+
+        if size == 0:
+            return []
+
+        tensor_data = torch.empty(size, dtype=torch.uint8)
+        dist.broadcast(tensor_data, src=0, group=dist_group)
+
+        serialized_data = bytes(tensor_data.tolist())
+        data = pickle.loads(serialized_data)
+        return data
+
+
 class ControllerSingle:
-    def __init__(self, model_client: ModelTpClient, port_args: PortArgs):
+    """A controller that manages a group of tensor parallel workers."""
+
+    def __init__(self, server_args: ServerArgs, port_args: PortArgs, model_overide_args: dict):
+        # Parse args
+        self.server_args = server_args
+
         # Init communication
-        context = zmq.asyncio.Context(2)
+        context = zmq.Context(2)
         self.recv_from_tokenizer = context.socket(zmq.PULL)
         self.recv_from_tokenizer.bind(f"tcp://127.0.0.1:{port_args.router_port}")

@@ -31,44 +107,52 @@ class ControllerSingle:
             f"tcp://127.0.0.1:{port_args.detokenizer_port}"
         )

-        # Init status
-        self.model_client = model_client
-        self.recv_reqs = []
+        # Init model server
+        tp_size_local = server_args.tp_size // server_args.nnodes
+        gpu_ids = [i for _ in range(server_args.nnodes) for i in range(tp_size_local)]

-        # Init some configs
-        self.request_dependency_delay = global_config.request_dependency_delay
+        # Launch other tp ranks
+        if tp_size_local > 1:
+            tp_rank_range = range(1, tp_size_local)
+            self.tp_procs = launch_tp_servers(
+                gpu_ids, tp_rank_range, server_args,
+                port_args.model_port_args[0], model_overide_args)

-    async def loop_for_forward(self):
+        # Launch tp rank 0
+        self.tp_server = ModelTpServer(
+            gpu_ids[0],
+            0,
+            server_args,
+            port_args.model_port_args[0],
+            model_overide_args,
+        )
+        self.tp_cpu_group = self.tp_server.model_runner.tp_group.cpu_group
+
+    def loop_for_forward(self):
         while True:
-            next_step_input = list(self.recv_reqs)
-            self.recv_reqs = []
-            out_pyobjs = await self.model_client.step(next_step_input)
+            recv_reqs = self.recv_requests()
+
+            if self.server_args.tp_size > 1:
+                broadcast_recv_input(recv_reqs, 0, self.tp_cpu_group)
+
+            out_pyobjs = self.tp_server.exposed_step(recv_reqs)

             for obj in out_pyobjs:
                 self.send_to_detokenizer.send_pyobj(obj)

-            # async sleep for receiving the subsequent request and avoiding cache miss
-            slept = False
-            if len(out_pyobjs) != 0:
-                has_finished = any(
-                    [obj.finished_reason is not None for obj in out_pyobjs]
-                )
-                if has_finished:
-                    if self.request_dependency_delay > 0:
-                        slept = True
-                        await asyncio.sleep(self.request_dependency_delay)
-
-            if not slept:
-                await asyncio.sleep(global_config.wait_for_new_request_delay)
-
-    async def loop_for_recv_requests(self):
+    def recv_requests(self):
+        recv_reqs = []
         while True:
-            recv_req = await self.recv_from_tokenizer.recv_pyobj()
-            self.recv_reqs.append(recv_req)
+            try:
+                recv_req = self.recv_from_tokenizer.recv_pyobj(zmq.NOBLOCK)
+                recv_reqs.append(recv_req)
+            except zmq.ZMQError:
+                break
+        return recv_reqs


 def start_controller_process(
-    server_args: ServerArgs, port_args: PortArgs, pipe_writer, model_overide_args
+    server_args: ServerArgs, port_args: PortArgs, pipe_writer, model_overide_args: dict
 ):
     logging.basicConfig(
         level=getattr(logging, server_args.log_level.upper()),
@@ -76,27 +160,18 @@ def start_controller_process(
     )

     try:
-        tp_size_local = server_args.tp_size // server_args.nnodes
-        model_client = ModelTpClient(
-            [i for _ in range(server_args.nnodes) for i in range(tp_size_local)],
-            server_args,
-            port_args.model_port_args[0],
-            model_overide_args,
-        )
-        controller = ControllerSingle(model_client, port_args)
+        controller = ControllerSingle(server_args, port_args, model_overide_args)
     except Exception:
         pipe_writer.send(get_exception_traceback())
         raise

     pipe_writer.send("init ok")

-    loop = asyncio.new_event_loop()
-    loop.set_default_executor(ThreadPoolExecutor(max_workers=256))
-    asyncio.set_event_loop(loop)
-    loop.create_task(controller.loop_for_recv_requests())
     try:
-        loop.run_until_complete(controller.loop_for_forward())
+        controller.loop_for_forward()
     except Exception:
         logger.error("Exception in ControllerSingle:\n" + get_exception_traceback())
     finally:
+        for t in controller.tp_procs:
+            os.kill(t.pid, 9)
         kill_parent_process()
diff --git a/python/sglang/srt/managers/controller/model_runner.py b/python/sglang/srt/managers/controller/model_runner.py
index d68d9af32..80c40e4f5 100644
--- a/python/sglang/srt/managers/controller/model_runner.py
+++ b/python/sglang/srt/managers/controller/model_runner.py
@@ -11,7 +11,7 @@ import torch
 import torch.nn as nn
 from vllm.config import DeviceConfig, LoadConfig
 from vllm.config import ModelConfig as VllmModelConfig
-from vllm.distributed import init_distributed_environment, initialize_model_parallel
+from vllm.distributed import init_distributed_environment, initialize_model_parallel, get_tp_group
 from vllm.model_executor.model_loader import get_model
 from vllm.model_executor.models import ModelRegistry

@@ -75,6 +75,7 @@ class ModelRunner:
             distributed_init_method=nccl_init_method,
         )
         initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
+        self.tp_group = get_tp_group()
         total_gpu_memory = get_available_gpu_memory(
             self.gpu_id, distributed=self.tp_size > 1
         )
diff --git a/python/sglang/srt/managers/controller/tp_worker.py b/python/sglang/srt/managers/controller/tp_worker.py
index 1d22dfdf1..21569c966 100644
--- a/python/sglang/srt/managers/controller/tp_worker.py
+++ b/python/sglang/srt/managers/controller/tp_worker.py
@@ -53,7 +53,7 @@ class ModelTpServer:
         tp_rank: int,
         server_args: ServerArgs,
         model_port_args: ModelPortArgs,
-        model_overide_args,
+        model_overide_args: dict,
     ):
         server_args, model_port_args = obtain(server_args), obtain(model_port_args)
         suppress_other_loggers()
@@ -178,7 +178,7 @@ class ModelTpServer:
         self.new_token_ratio_recovery = global_config.new_token_ratio_recovery

     def exposed_step(self, recv_reqs):
-        if self.tp_size * self.dp_size != 1:
+        if not isinstance(recv_reqs, list):
             recv_reqs = obtain(recv_reqs)

         try:
@@ -206,11 +206,11 @@ class ModelTpServer:

     @torch.inference_mode()
     def forward_step(self):
-        new_batch = self.get_new_fill_batch()
+        new_batch = self.get_new_prefill_batch()

         if new_batch is not None:
-            # Run a new fill batch
-            self.forward_fill_batch(new_batch)
+            # Run a new prefill batch
+            self.forward_prefill_batch(new_batch)
             self.cache_filled_batch(new_batch)

             if not new_batch.is_empty():
@@ -219,7 +219,7 @@ class ModelTpServer:
             else:
                 self.running_batch.merge(new_batch)
         else:
-            # Run decode batch
+            # Run a decode batch
             if self.running_batch is not None:
                 # Run a few decode batches continuously for reducing overhead
                 for _ in range(global_config.num_continue_decode_steps):
@@ -312,7 +312,7 @@ class ModelTpServer:
             )
             self.forward_queue.append(req)

-    def get_new_fill_batch(self) -> Optional[Batch]:
+    def get_new_prefill_batch(self) -> Optional[Batch]:
         running_bs = (
             len(self.running_batch.reqs) if self.running_batch is not None else 0
         )
@@ -436,7 +436,7 @@ class ModelTpServer:
         self.forward_queue = [x for x in self.forward_queue if x not in can_run_list]
         return new_batch

-    def forward_fill_batch(self, batch: Batch):
+    def forward_prefill_batch(self, batch: Batch):
         # Build batch tensors
         batch.prepare_for_extend(
             self.model_config.vocab_size, self.int_token_logit_bias
         )
@@ -746,8 +746,8 @@ class ModelTpClient:
             # Init model
             assert len(gpu_ids) == 1
             self.model_server = ModelTpService().exposed_ModelTpServer(
-                0,
                 gpu_ids[0],
+                0,
                 server_args,
                 model_port_args,
                 model_overide_args,
diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py
index 6cda67dea..0a3f53b8b 100644
--- a/python/sglang/srt/server.py
+++ b/python/sglang/srt/server.py
@@ -33,9 +33,9 @@ from sglang.srt.managers.controller.manager_multi import (
     start_controller_process as start_controller_process_multi,
 )
 from sglang.srt.managers.controller.manager_single import (
+    launch_tp_servers,
     start_controller_process as start_controller_process_single,
 )
-from sglang.srt.managers.controller.tp_worker import ModelTpService
 from sglang.srt.managers.detokenizer_manager import start_detokenizer_process
 from sglang.srt.managers.io_struct import GenerateReqInput
 from sglang.srt.managers.tokenizer_manager import TokenizerManager
@@ -53,7 +53,6 @@ from sglang.srt.utils import (
     enable_show_time_cost,
     receive_addrs,
     send_addrs_to_rank_0,
-    start_rpyc_service_process,
 )
 from sglang.utils import get_exception_traceback

@@ -192,21 +191,17 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
         model_port_args=model_port_args,
     )

-    # TODO multi-node dp is not supported
-    assert not (server_args.dp_size > 1 and server_args.node_rank is not None)
+    # Handle multi-node tp
     if server_args.nnodes > 1:
+        assert server_args.dp_size == 1, "Multi-node dp is not supported."
+
         if server_args.node_rank != 0:
-            send_addrs_to_rank_0(model_port_args[0], server_args)
-        else:
-            receive_addrs(model_port_args[0], server_args)
-        for i in range(tp_size_local):
-            start_rpyc_service_process(
-                ModelTpService, model_port_args[0].model_tp_ports[i]
-            )
-        if server_args.node_rank != 0:
-            logger.info(
-                f"[node_rank={server_args.node_rank}]: Listen for connections..."
-            )
+            tp_size_local = server_args.tp_size // server_args.nnodes
+            gpu_ids = [i for _ in range(server_args.nnodes) for i in range(tp_size_local)]
+            tp_rank_range = list(range(server_args.node_rank * tp_size_local,
+                                       (server_args.node_rank + 1) * tp_size_local))
+            procs = launch_tp_servers(gpu_ids, tp_rank_range, server_args,
+                                      port_args.model_port_args[0], model_overide_args)
             while True:
                 pass

diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index 46dfc25d2..b4f79c066 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -67,10 +67,12 @@ class ServerArgs:
         if self.tokenizer_path is None:
             self.tokenizer_path = self.model_path
         if self.mem_fraction_static is None:
-            if self.tp_size >= 8:
+            if self.tp_size >= 16:
+                self.mem_fraction_static = 0.74
+            elif self.tp_size >= 8:
                 self.mem_fraction_static = 0.78
             elif self.tp_size >= 4:
-                self.mem_fraction_static = 0.80
+                self.mem_fraction_static = 0.82
             elif self.tp_size >= 2:
                 self.mem_fraction_static = 0.85
             else: