diff --git a/benchmark/latency_throughput/README.md b/benchmark/latency_throughput/README.md
index 26ff4ffc2..b303bc29a 100644
--- a/benchmark/latency_throughput/README.md
+++ b/benchmark/latency_throughput/README.md
@@ -20,7 +20,7 @@ python3 bench_throughput.py --backend srt --tokenizer meta-llama/Llama-2-7b-chat
 ```
 # run synthetic
-python3 synthetic_benchmark.py --backend srt --tokenizer meta-llama/Llama-2-7b-chat-hf --num-prompt 1000 --request-rate 100 --input-len 1024 --output-len 256 --port 30000
+python3 bench_throughput.py --backend srt --tokenizer meta-llama/Llama-2-7b-chat-hf --num-prompt 1000 --request-rate 100 --input-len 1024 --output-len 256 --port 30000
 ```
@@ -36,7 +36,7 @@ python3 bench_throughput.py --backend vllm --tokenizer meta-llama/Llama-2-7b-cha
 ```
 # run synthetic
-python3 synthetic_benchmark.py --backend vllm --tokenizer meta-llama/Llama-2-7b-chat-hf --num-prompt 1000 --request-rate 100 --input-len 1024 --output-len 256 --port 30000
+python3 bench_throughput.py --backend vllm --tokenizer meta-llama/Llama-2-7b-chat-hf --num-prompt 1000 --request-rate 100 --input-len 1024 --output-len 256 --port 30000
 ```
diff --git a/benchmark/latency_throughput/test_latency.py b/benchmark/latency_throughput/test_latency.py
index 593df054c..fc66b231e 100644
--- a/benchmark/latency_throughput/test_latency.py
+++ b/benchmark/latency_throughput/test_latency.py
@@ -24,7 +24,7 @@ if __name__ == "__main__":
         raise ValueError(f"Invalid backend: {args.backend}")

     url = f"{args.host}:{args.port}"
-    a = random.randint(0, 1 << 20)
+    a = 20
     max_new_tokens = 256
     prompt = f"{a, }"
diff --git a/python/sglang/launch_server.py b/python/sglang/launch_server.py
index a8c1e2feb..cb2b07251 100644
--- a/python/sglang/launch_server.py
+++ b/python/sglang/launch_server.py
@@ -2,7 +2,8 @@

 import argparse

-from sglang.srt.server import ServerArgs, launch_server
+from sglang.srt.server import launch_server
+from sglang.srt.server_args import ServerArgs

 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
diff --git a/python/sglang/srt/managers/controller/manager_single.py b/python/sglang/srt/managers/controller/manager_single.py
index e4d02d036..8b8625754 100644
--- a/python/sglang/srt/managers/controller/manager_single.py
+++ b/python/sglang/srt/managers/controller/manager_single.py
@@ -76,8 +76,9 @@ def start_controller_process(
     )

     try:
+        tp_size_local = server_args.tp_size // server_args.nnodes
         model_client = ModelTpClient(
-            list(range(server_args.tp_size)),
+            [i for _ in range(server_args.nnodes) for i in range(tp_size_local)],
             server_args,
             port_args.model_port_args[0],
             model_overide_args,
diff --git a/python/sglang/srt/managers/controller/model_runner.py b/python/sglang/srt/managers/controller/model_runner.py
index 692ed7ac3..95f6b4e5a 100644
--- a/python/sglang/srt/managers/controller/model_runner.py
+++ b/python/sglang/srt/managers/controller/model_runner.py
@@ -246,12 +246,16 @@ class ModelRunner:
         torch.cuda.set_device(self.gpu_id)
         logger.info(f"[gpu_id={self.gpu_id}] Init nccl begin.")
         monkey_patch_vllm_p2p_access_check(self.gpu_id)
+        if server_args.nccl_init_addr:
+            nccl_init_method = f"tcp://{server_args.nccl_init_addr}"
+        else:
+            nccl_init_method = f"tcp://127.0.0.1:{self.nccl_port}"
         init_distributed_environment(
             backend="nccl",
             world_size=self.tp_size,
             rank=self.tp_rank,
             local_rank=self.gpu_id,
-            distributed_init_method=f"tcp://127.0.0.1:{self.nccl_port}",
+            distributed_init_method=nccl_init_method
         )
         initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
         total_gpu_memory = get_available_gpu_memory(
@@ -311,7 +315,7 @@ class ModelRunner:
             self.gpu_id, distributed=self.tp_size > 1
         )
         head_dim = self.model_config.head_dim
-        head_num = self.model_config.num_key_value_heads // self.tp_size
+        head_num = self.model_config.get_num_kv_heads(self.tp_size)
         cell_size = head_num * head_dim * self.model_config.num_hidden_layers * 2 * 2
         rest_memory = available_gpu_memory - total_gpu_memory * (
             1 - self.mem_fraction_static
@@ -324,7 +328,7 @@ class ModelRunner:

         if self.max_total_num_tokens <= 0:
             raise RuntimeError(
-                "Not enought memory. Please try to increase --mem-fraction-static."
+                "Not enough memory. Please try to increase --mem-fraction-static."
             )

         self.req_to_token_pool = ReqToTokenPool(
diff --git a/python/sglang/srt/managers/controller/tp_worker.py b/python/sglang/srt/managers/controller/tp_worker.py
index d8fee6537..c49d4b01e 100644
--- a/python/sglang/srt/managers/controller/tp_worker.py
+++ b/python/sglang/srt/managers/controller/tp_worker.py
@@ -37,7 +37,8 @@ from sglang.srt.utils import (
     get_int_token_logit_bias,
     is_multimodal_model,
     set_random_seed,
-    start_rpyc_process,
+    start_rpyc_service_process,
+    connect_rpyc_service,
     suppress_other_loggers,
 )
 from sglang.utils import get_exception_traceback
@@ -770,12 +771,17 @@ class ModelTpClient:
         else:
             with ThreadPoolExecutor(self.tp_size) as executor:
                 # Launch model processes
-                rets = executor.map(
-                    lambda args: start_rpyc_process(*args),
-                    [(ModelTpService, p) for p in model_port_args.model_tp_ports],
-                )
-                self.model_services = [x[0] for x in rets]
-                self.procs = [x[1] for x in rets]
+                if server_args.nnodes == 1:
+                    self.procs = list(executor.map(
+                        lambda args: start_rpyc_service_process(*args),
+                        [(ModelTpService, p) for p in model_port_args.model_tp_ports],
+                    ))
+                    addrs = [("localhost", p) for p in model_port_args.model_tp_ports]
+                else:
+                    addrs = [(ip, port) for ip, port in zip(model_port_args.model_tp_ips, model_port_args.model_tp_ports)]
+
+                self.model_services = list(executor.map(
+                    lambda args: connect_rpyc_service(*args), addrs))

                 # Init model
                 def init_model(i):
@@ -787,7 +793,7 @@
                         model_overide_args,
                     )

-                self.model_servers = executor.map(init_model, range(self.tp_size))
+                self.model_servers = list(executor.map(init_model, range(self.tp_size)))

                 # Wrap functions
                 def async_wrap(func_name):
diff --git a/python/sglang/srt/model_config.py b/python/sglang/srt/model_config.py
index 315ab4163..715b7fd21 100644
--- a/python/sglang/srt/model_config.py
+++ b/python/sglang/srt/model_config.py
@@ -71,7 +71,11 @@ class ModelConfig:
             return 1

         # For DBRX and MPT
-        if self.hf_config.model_type in ["dbrx", "mpt"]:
+        if self.hf_config.model_type in ["mpt"]:
+            if "kv_n_heads" in self.hf_config.attn_config:
+                return self.hf_config.attn_config["kv_n_heads"]
+            return self.hf_config.num_attention_heads
+        if self.hf_config.model_type in ["dbrx"]:
             return getattr(
                 self.hf_config.attn_config,
                 "kv_n_heads",
diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py
index c98a760c5..4e088a350 100644
--- a/python/sglang/srt/server.py
+++ b/python/sglang/srt/server.py
@@ -35,6 +35,7 @@ from sglang.srt.managers.controller.manager_multi import (
 from sglang.srt.managers.controller.manager_single import (
     start_controller_process as start_controller_process_single,
 )
+from sglang.srt.managers.controller.tp_worker import ModelTpService
 from sglang.srt.managers.detokenizer_manager import start_detokenizer_process
 from sglang.srt.managers.io_struct import GenerateReqInput
 from sglang.srt.managers.tokenizer_manager import TokenizerManager
@@ -50,9 +51,13 @@ from sglang.srt.utils import (
     allocate_init_ports,
     assert_pkg_version,
     enable_show_time_cost,
+    send_addrs_to_rank_0,
+    receive_addrs,
+    start_rpyc_service_process,
 )
 from sglang.utils import get_exception_traceback

+
 asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
@@ -151,21 +156,23 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
         load_chat_template_for_openai_api(server_args.chat_template)

     # Allocate ports
+    assert server_args.tp_size % server_args.nnodes == 0
+    tp_size_local = server_args.tp_size // server_args.nnodes
     server_args.port, server_args.additional_ports = allocate_init_ports(
         server_args.port,
         server_args.additional_ports,
-        server_args.tp_size,
+        tp_size_local,
         server_args.dp_size,
     )
     ports = server_args.additional_ports
-    tp = server_args.tp_size
     model_port_args = []
     for i in range(server_args.dp_size):
         model_port_args.append(
             ModelPortArgs(
-                nccl_port=ports[3 + i * (tp + 1)],
-                model_tp_ports=ports[3 + i * (tp + 1) + 1 : 3 + (i + 1) * (tp + 1)],
+                nccl_port=ports[3 + i * (tp_size_local + 1)],
+                model_tp_ips=[None] * tp_size_local,
+                model_tp_ports=ports[3 + i * (tp_size_local + 1) + 1 : 3 + (i + 1) * (tp_size_local + 1)],
             )
         )
     port_args = PortArgs(
@@ -175,6 +182,20 @@
         model_port_args=model_port_args,
     )

+    # TODO multi-node dp is not supported
+    assert not (server_args.dp_size > 1 and server_args.node_rank is not None)
+    if server_args.nnodes > 1:
+        if server_args.node_rank != 0:
+            send_addrs_to_rank_0(model_port_args[0], server_args)
+        else:
+            receive_addrs(model_port_args[0], server_args)
+        for i in range(tp_size_local):
+            start_rpyc_service_process(ModelTpService, model_port_args[0].model_tp_ports[i])
+        if server_args.node_rank != 0:
+            print("Listen for connections...")
+            while True:
+                pass
+
     # Launch processes
     tokenizer_manager = TokenizerManager(server_args, port_args, model_overide_args)
     pipe_router_reader, pipe_router_writer = mp.Pipe(duplex=False)
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index ae74a4390..183afb3b8 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -56,6 +56,11 @@
     disable_regex_jump_forward: bool = False
     disable_disk_cache: bool = False

+    # Distributed args
+    nccl_init_addr: Optional[str] = None
+    nnodes: int = 1
+    node_rank: Optional[int] = None
+
     def __post_init__(self):
         if self.tokenizer_path is None:
             self.tokenizer_path = self.model_path
@@ -252,6 +257,24 @@
             ],
         )

+        # Multi-node distributed serving args
+        parser.add_argument(
+            "--nccl-init-addr",
+            type=str,
+            help="The nccl init address of multi-node server."
+        )
+        parser.add_argument(
+            "--nnodes",
+            type=int,
+            default=1,
+            help="Number of nodes"
+        )
+        parser.add_argument(
+            "--node-rank",
+            type=int,
+            help="The node rank."
+        )
+
         # Optimization/debug options
         parser.add_argument(
             "--enable-flashinfer",
@@ -300,6 +323,7 @@ class ServerArgs:


 @dataclasses.dataclass
 class ModelPortArgs:
     nccl_port: int
+    model_tp_ips: List[str]
     model_tp_ports: List[int]
diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py
index 8a7f33eb6..f93e0be36 100644
--- a/python/sglang/srt/utils.py
+++ b/python/sglang/srt/utils.py
@@ -1,11 +1,13 @@
 """Common utilities."""

 import base64
+import fcntl
 import logging
 import multiprocessing
 import os
 import random
 import socket
+import struct
 import time
 from importlib.metadata import PackageNotFoundError, version
 from io import BytesIO
@@ -369,23 +371,7 @@ def load_image(image_file):
     return image, image_size


-def init_rpyc_service(service: rpyc.Service, port: int):
-    t = ThreadedServer(
-        service=service,
-        port=port,
-        protocol_config={
-            "allow_public_attrs": True,
-            "allow_pickle": True,
-            "sync_request_timeout": 3600,
-        },
-    )
-    t.logger.setLevel(logging.WARN)
-    t.start()
-
-
-def connect_to_rpyc_service(port, host="localhost"):
-    time.sleep(1)
-
+def connect_rpyc_service(host, port):
     repeat_count = 0
     while repeat_count < 20:
         try:
@@ -399,22 +385,33 @@
             con = rpyc.connect(
                 host,
                 port,
                 config={
                     "allow_public_attrs": True,
                     "allow_pickle": True,
                     "sync_request_timeout": 3600,
                 },
             )
             break
-        except ConnectionRefusedError:
+        except ConnectionRefusedError as e:
             time.sleep(1)
         repeat_count += 1

     if repeat_count == 20:
-        raise RuntimeError("init rpc env error!")
+        raise RuntimeError(f"Connect rpyc error: {e}")

     return con.root


-def start_rpyc_process(service: rpyc.Service, port: int):
-    # Return the proxy and the process
-    proc = multiprocessing.Process(target=init_rpyc_service, args=(service, port))
+def start_rpyc_service(service: rpyc.Service, port: int):
+    t = ThreadedServer(
+        service=service,
+        port=port,
+        protocol_config={
+            "allow_public_attrs": True,
+            "allow_pickle": True,
+            "sync_request_timeout": 3600,
+        },
+    )
+    t.logger.setLevel(logging.WARN)
+    t.start()
+
+
+def start_rpyc_service_process(service: rpyc.Service, port: int):
+    proc = multiprocessing.Process(target=start_rpyc_service, args=(service, port))
     proc.start()
-    proxy = connect_to_rpyc_service(port)
-    assert proc.is_alive()
-    return proxy, proc
+    return proc


 def suppress_other_loggers():
@@ -487,3 +484,66 @@ class APIKeyValidatorMiddleware(BaseHTTPMiddleware):
             )
         response = await call_next(request)
         return response
+
+
+def get_ip_address(ifname):
+    """
+    Get the IP address of a network interface.
+
+    :param ifname: Name of the network interface (e.g., 'eth0')
+    :return: IP address of the network interface
+    """
+    s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
+    ip_address = fcntl.ioctl(
+        s.fileno(),
+        0x8915,  # SIOCGIFADDR
+        struct.pack('256s', bytes(ifname[:15], 'utf-8'))
+    )[20:24]
+    return socket.inet_ntoa(ip_address)
+
+
+def send_addrs_to_rank_0(model_port_args, server_args):
+    assert server_args.node_rank != 0 and server_args.dp_size == 1
+    import torch.distributed as dist
+
+    ifname = os.environ.get("SGLANG_SOCKET_IFNAME", os.environ.get("NCCL_SOCKET_IFNAME", "eth0"))
+    ip_addr = get_ip_address(ifname)
+
+    num_tp_ports = server_args.tp_size // server_args.nnodes
+    model_port_args.model_tp_ips[:num_tp_ports] = [ip_addr] * num_tp_ports
+    ip_addr = [int(x) for x in ip_addr.split(".")]
+    addrs_tensor = torch.tensor(ip_addr + model_port_args.model_tp_ports, dtype=torch.int)
+
+    init_method = f"tcp://{server_args.nccl_init_addr}"
+    dist.init_process_group(backend="gloo", init_method=init_method, rank=server_args.node_rank, world_size=server_args.nnodes)
+    dist.send(addrs_tensor, dst=0)
+    print(f"Node {server_args.node_rank} sent: ip_address {ip_addr} and ports {model_port_args.model_tp_ports}")
+
+    dist.barrier()
+    dist.destroy_process_group()
+
+
+def receive_addrs(model_port_args, server_args):
+    assert server_args.node_rank == 0 and server_args.dp_size == 1
+    import torch.distributed as dist
+
+    ifname = os.environ.get("SGLANG_SOCKET_IFNAME", os.environ.get("NCCL_SOCKET_IFNAME", "eth0"))
+    ip_addr = get_ip_address(ifname)
+
+    num_tp_ports = server_args.tp_size // server_args.nnodes
+    model_port_args.model_tp_ips[:num_tp_ports] = [ip_addr] * num_tp_ports
+
+    init_method = f"tcp://{server_args.nccl_init_addr}"
+    dist.init_process_group(backend="gloo", init_method=init_method, rank=server_args.node_rank, world_size=server_args.nnodes)
+
+    for src_rank in range(1, server_args.nnodes):
+        tensor = torch.zeros(4 + num_tp_ports, dtype=torch.int)
+        dist.recv(tensor, src=src_rank)
+        ip = ".".join([str(x) for x in tensor[:4].tolist()])
+        ports = tensor[4:].tolist()
+        model_port_args.model_tp_ips[num_tp_ports * src_rank: num_tp_ports * (src_rank + 1)] = [ip] * num_tp_ports
+        model_port_args.model_tp_ports[num_tp_ports * src_rank: num_tp_ports * (src_rank + 1)] = ports
+        print(f"Node 0 received from rank {src_rank}: {tensor.tolist()}")
+
+    dist.barrier()
+    dist.destroy_process_group()
\ No newline at end of file
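
Usage sketch: the flags added in server_args.py above (--nnodes, --node-rank, --nccl-init-addr) suggest a multi-node tensor-parallel launch roughly like the commands below. This is an illustration, not taken from the patch: the model path, the --model-path and --tp-size flag names (assumed to already exist in ServerArgs), and the 10.0.0.1:50000 rendezvous address are placeholders. Per the patch, tp_size must be divisible by nnodes, and SGLANG_SOCKET_IFNAME or NCCL_SOCKET_IFNAME selects the interface used for address exchange (default eth0).

```
# node 0 (rank 0), assuming its IP is 10.0.0.1 and port 50000 is free
python3 -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --tp-size 4 --nnodes 2 --node-rank 0 --nccl-init-addr 10.0.0.1:50000

# node 1 (rank 1), pointing at the same rendezvous address
python3 -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --tp-size 4 --nnodes 2 --node-rank 1 --nccl-init-addr 10.0.0.1:50000
```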