Multi-node Tensor Parallelism (#550)

Co-authored-by: Lianmin Zheng <lianminzheng@gmail.com>
Author: Ying Sheng
Date: 2024-06-17 20:41:24 -07:00
Committed by: GitHub
Parent commit: 53a7ebd89a
Commit: 09593e9bc9
10 changed files with 167 additions and 46 deletions

View File

@@ -20,7 +20,7 @@ python3 bench_throughput.py --backend srt --tokenizer meta-llama/Llama-2-7b-chat
 ```
 # run synthetic
-python3 synthetic_benchmark.py --backend srt --tokenizer meta-llama/Llama-2-7b-chat-hf --num-prompt 1000 --request-rate 100 --input-len 1024 --output-len 256 --port 30000
+python3 bench_throughput.py --backend srt --tokenizer meta-llama/Llama-2-7b-chat-hf --num-prompt 1000 --request-rate 100 --input-len 1024 --output-len 256 --port 30000
 ```

@@ -36,7 +36,7 @@ python3 bench_throughput.py --backend vllm --tokenizer meta-llama/Llama-2-7b-cha
 ```
 # run synthetic
-python3 synthetic_benchmark.py --backend vllm --tokenizer meta-llama/Llama-2-7b-chat-hf --num-prompt 1000 --request-rate 100 --input-len 1024 --output-len 256 --port 30000
+python3 bench_throughput.py --backend vllm --tokenizer meta-llama/Llama-2-7b-chat-hf --num-prompt 1000 --request-rate 100 --input-len 1024 --output-len 256 --port 30000
 ```

View File

@@ -24,7 +24,7 @@ if __name__ == "__main__":
         raise ValueError(f"Invalid backend: {args.backend}")

     url = f"{args.host}:{args.port}"
-    a = random.randint(0, 1 << 20)
+    a = 20
     max_new_tokens = 256
     prompt = f"{a, }"

View File

@@ -2,7 +2,8 @@
 import argparse

-from sglang.srt.server import ServerArgs, launch_server
+from sglang.srt.server import launch_server
+from sglang.srt.server_args import ServerArgs

 if __name__ == "__main__":
     parser = argparse.ArgumentParser()

View File

@@ -76,8 +76,9 @@ def start_controller_process(
     )

     try:
+        tp_size_local = server_args.tp_size // server_args.nnodes
         model_client = ModelTpClient(
-            list(range(server_args.tp_size)),
+            [i for _ in range(server_args.nnodes) for i in range(tp_size_local)],
             server_args,
             port_args.model_port_args[0],
             model_overide_args,
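
For intuition, the new gpu_ids argument repeats the node-local device ranks once per node instead of enumerating global ranks. A minimal sketch with assumed sizes (tp_size=8, nnodes=2):

```python
# Sketch of the new gpu_ids expression; the sizes are assumed example values.
tp_size, nnodes = 8, 2
tp_size_local = tp_size // nnodes
gpu_ids = [i for _ in range(nnodes) for i in range(tp_size_local)]
print(gpu_ids)  # [0, 1, 2, 3, 0, 1, 2, 3] -- local CUDA device ids, repeated per node
```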

View File

@@ -246,12 +246,16 @@ class ModelRunner:
         torch.cuda.set_device(self.gpu_id)
         logger.info(f"[gpu_id={self.gpu_id}] Init nccl begin.")
         monkey_patch_vllm_p2p_access_check(self.gpu_id)
+        if server_args.nccl_init_addr:
+            nccl_init_method = f"tcp://{server_args.nccl_init_addr}"
+        else:
+            nccl_init_method = f"tcp://127.0.0.1:{self.nccl_port}"
         init_distributed_environment(
             backend="nccl",
             world_size=self.tp_size,
             rank=self.tp_rank,
             local_rank=self.gpu_id,
-            distributed_init_method=f"tcp://127.0.0.1:{self.nccl_port}",
+            distributed_init_method=nccl_init_method,
         )
         initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
         total_gpu_memory = get_available_gpu_memory(

@@ -311,7 +315,7 @@ class ModelRunner:
             self.gpu_id, distributed=self.tp_size > 1
         )
         head_dim = self.model_config.head_dim
-        head_num = self.model_config.num_key_value_heads // self.tp_size
+        head_num = self.model_config.get_num_kv_heads(self.tp_size)
         cell_size = head_num * head_dim * self.model_config.num_hidden_layers * 2 * 2
         rest_memory = available_gpu_memory - total_gpu_memory * (
             1 - self.mem_fraction_static

@@ -324,7 +328,7 @@ class ModelRunner:
         if self.max_total_num_tokens <= 0:
             raise RuntimeError(
-                "Not enought memory. Please try to increase --mem-fraction-static."
+                "Not enough memory. Please try to increase --mem-fraction-static."
             )
         self.req_to_token_pool = ReqToTokenPool(
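
The cell_size expression above is the per-token KV-cache footprint: kv_heads × head_dim × layers × 2 tensors (K and V) × 2 bytes (fp16). A quick sanity check with illustrative Llama-2-7B-style shapes (not read from a real config, tp_size=1):

```python
# Per-token KV-cache bytes, mirroring the cell_size formula in ModelRunner.
head_num = 32             # KV heads on this GPU after the TP split (assumed)
head_dim = 128
num_hidden_layers = 32
cell_size = head_num * head_dim * num_hidden_layers * 2 * 2  # K and V, fp16 bytes
print(cell_size)  # 524288 -> 0.5 MiB of KV cache per token
```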

View File

@@ -37,7 +37,8 @@ from sglang.srt.utils import (
     get_int_token_logit_bias,
     is_multimodal_model,
     set_random_seed,
-    start_rpyc_process,
+    start_rpyc_service_process,
+    connect_rpyc_service,
     suppress_other_loggers,
 )
 from sglang.utils import get_exception_traceback

@@ -770,12 +771,17 @@ class ModelTpClient:
         else:
             with ThreadPoolExecutor(self.tp_size) as executor:
                 # Launch model processes
-                rets = executor.map(
-                    lambda args: start_rpyc_process(*args),
-                    [(ModelTpService, p) for p in model_port_args.model_tp_ports],
-                )
-                self.model_services = [x[0] for x in rets]
-                self.procs = [x[1] for x in rets]
+                if server_args.nnodes == 1:
+                    self.procs = list(executor.map(
+                        lambda args: start_rpyc_service_process(*args),
+                        [(ModelTpService, p) for p in model_port_args.model_tp_ports],
+                    ))
+                    addrs = [("localhost", p) for p in model_port_args.model_tp_ports]
+                else:
+                    addrs = [(ip, port) for ip, port in zip(model_port_args.model_tp_ips, model_port_args.model_tp_ports)]
+                self.model_services = list(executor.map(
+                    lambda args: connect_rpyc_service(*args), addrs))

                 # Init model
                 def init_model(i):

@@ -787,7 +793,7 @@ class ModelTpClient:
                         model_overide_args,
                     )

-                self.model_servers = executor.map(init_model, range(self.tp_size))
+                self.model_servers = list(executor.map(init_model, range(self.tp_size)))

                 # Wrap functions
                 def async_wrap(func_name):
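
The single-node branch keeps the old behavior (spawn rpyc servers locally) but now reconnects to them through the same connect_rpyc_service path used for remote workers. A self-contained sketch of that spawn-then-connect pattern, with a toy service standing in for ModelTpService:

```python
# Minimal sketch of the launch-then-connect pattern above. The port and the
# EchoService are illustrative stand-ins, not the real ModelTpService.
import multiprocessing
import time

import rpyc
from rpyc.utils.server import ThreadedServer

class EchoService(rpyc.Service):
    def exposed_echo(self, x):  # rpyc exposes exposed_* methods via conn.root
        return x

def serve(port):
    ThreadedServer(service=EchoService, port=port).start()

if __name__ == "__main__":
    proc = multiprocessing.Process(target=serve, args=(18861,))
    proc.start()                             # like start_rpyc_service_process
    time.sleep(1)                            # the real code retries on refusal
    conn = rpyc.connect("localhost", 18861)  # like connect_rpyc_service
    print(conn.root.echo("ping"))            # -> "ping"
    conn.close()
    proc.terminate()
```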

View File

@@ -71,7 +71,11 @@ class ModelConfig:
             return 1

         # For DBRX and MPT
-        if self.hf_config.model_type in ["dbrx", "mpt"]:
+        if self.hf_config.model_type in ["mpt"]:
+            if "kv_n_heads" in self.hf_config.attn_config:
+                return self.hf_config.attn_config["kv_n_heads"]
+            return self.hf_config.num_attention_heads
+        if self.hf_config.model_type in ["dbrx"]:
             return getattr(
                 self.hf_config.attn_config,
                 "kv_n_heads",

View File

@@ -35,6 +35,7 @@ from sglang.srt.managers.controller.manager_multi import (
 from sglang.srt.managers.controller.manager_single import (
     start_controller_process as start_controller_process_single,
 )
+from sglang.srt.managers.controller.tp_worker import ModelTpService
 from sglang.srt.managers.detokenizer_manager import start_detokenizer_process
 from sglang.srt.managers.io_struct import GenerateReqInput
 from sglang.srt.managers.tokenizer_manager import TokenizerManager

@@ -50,9 +51,13 @@ from sglang.srt.utils import (
     allocate_init_ports,
     assert_pkg_version,
     enable_show_time_cost,
+    send_addrs_to_rank_0,
+    receive_addrs,
+    start_rpyc_service_process,
 )
 from sglang.utils import get_exception_traceback

 asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())

@@ -151,21 +156,23 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
         load_chat_template_for_openai_api(server_args.chat_template)

     # Allocate ports
+    assert server_args.tp_size % server_args.nnodes == 0
+    tp_size_local = server_args.tp_size // server_args.nnodes
     server_args.port, server_args.additional_ports = allocate_init_ports(
         server_args.port,
         server_args.additional_ports,
-        server_args.tp_size,
+        tp_size_local,
         server_args.dp_size,
     )
     ports = server_args.additional_ports
-    tp = server_args.tp_size
     model_port_args = []
     for i in range(server_args.dp_size):
         model_port_args.append(
             ModelPortArgs(
-                nccl_port=ports[3 + i * (tp + 1)],
-                model_tp_ports=ports[3 + i * (tp + 1) + 1 : 3 + (i + 1) * (tp + 1)],
+                nccl_port=ports[3 + i * (tp_size_local + 1)],
+                model_tp_ips=[None] * tp_size_local,
+                model_tp_ports=ports[3 + i * (tp_size_local + 1) + 1 : 3 + (i + 1) * (tp_size_local + 1)],
             )
         )
     port_args = PortArgs(

@@ -175,6 +182,20 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
         model_port_args=model_port_args,
     )

+    # TODO multi-node dp is not supported
+    assert not (server_args.dp_size > 1 and server_args.node_rank is not None)
+    if server_args.nnodes > 1:
+        if server_args.node_rank != 0:
+            send_addrs_to_rank_0(model_port_args[0], server_args)
+        else:
+            receive_addrs(model_port_args[0], server_args)
+        for i in range(tp_size_local):
+            start_rpyc_service_process(ModelTpService, model_port_args[0].model_tp_ports[i])
+        if server_args.node_rank != 0:
+            print("Listen for connections...")
+            while True:
+                pass
+
     # Launch processes
     tokenizer_manager = TokenizerManager(server_args, port_args, model_overide_args)
     pipe_router_reader, pipe_router_writer = mp.Pipe(duplex=False)
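
Putting the new flags together, a two-node launch might look like the sketch below. The model path, interface address, and port are illustrative, and the --model-path/--tp-size flag names and launcher module come from the existing CLI rather than this diff, so treat them as assumptions:

```python
# Hypothetical two-node launch; all concrete values are examples.
#
# Node 0 (also the rank-0 endpoint for the address exchange):
#   python3 -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf \
#       --tp-size 4 --nnodes 2 --node-rank 0 --nccl-init-addr 192.168.0.1:50000
# Node 1 (same command, different rank):
#   python3 -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf \
#       --tp-size 4 --nnodes 2 --node-rank 1 --nccl-init-addr 192.168.0.1:50000
#
# Programmatic equivalent for node 0:
from sglang.srt.server_args import ServerArgs

args = ServerArgs(
    model_path="meta-llama/Llama-2-7b-chat-hf",
    tp_size=4,                           # 2 GPUs per node after the local split
    nnodes=2,
    node_rank=0,
    nccl_init_addr="192.168.0.1:50000",  # rank-0 host:port reachable by all nodes
)
```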

View File

@@ -56,6 +56,11 @@ class ServerArgs:
     disable_regex_jump_forward: bool = False
     disable_disk_cache: bool = False

+    # Distributed args
+    nccl_init_addr: Optional[str] = None
+    nnodes: int = 1
+    node_rank: Optional[int] = None
+
     def __post_init__(self):
         if self.tokenizer_path is None:
             self.tokenizer_path = self.model_path

@@ -252,6 +257,24 @@ class ServerArgs:
             ],
         )

+        # Multi-node distributed serving args
+        parser.add_argument(
+            "--nccl-init-addr",
+            type=str,
+            help="The nccl init address of multi-node server.",
+        )
+        parser.add_argument(
+            "--nnodes",
+            type=int,
+            default=1,
+            help="Number of nodes",
+        )
+        parser.add_argument(
+            "--node-rank",
+            type=int,
+            help="The node rank.",
+        )
+
         # Optimization/debug options
         parser.add_argument(
             "--enable-flashinfer",

@@ -300,6 +323,7 @@ class ServerArgs:
 @dataclasses.dataclass
 class ModelPortArgs:
     nccl_port: int
+    model_tp_ips: List[str]
     model_tp_ports: List[int]

View File

@@ -1,11 +1,13 @@
"""Common utilities.""" """Common utilities."""
import base64 import base64
import fcntl
import logging import logging
import multiprocessing import multiprocessing
import os import os
import random import random
import socket import socket
import struct
import time import time
from importlib.metadata import PackageNotFoundError, version from importlib.metadata import PackageNotFoundError, version
from io import BytesIO from io import BytesIO
@@ -369,23 +371,7 @@ def load_image(image_file):
return image, image_size return image, image_size
def init_rpyc_service(service: rpyc.Service, port: int): def connect_rpyc_service(host, port):
t = ThreadedServer(
service=service,
port=port,
protocol_config={
"allow_public_attrs": True,
"allow_pickle": True,
"sync_request_timeout": 3600,
},
)
t.logger.setLevel(logging.WARN)
t.start()
def connect_to_rpyc_service(port, host="localhost"):
time.sleep(1)
repeat_count = 0 repeat_count = 0
while repeat_count < 20: while repeat_count < 20:
try: try:
@@ -399,22 +385,33 @@ def connect_to_rpyc_service(port, host="localhost"):
}, },
) )
break break
except ConnectionRefusedError: except ConnectionRefusedError as e:
time.sleep(1) time.sleep(1)
repeat_count += 1 repeat_count += 1
if repeat_count == 20: if repeat_count == 20:
raise RuntimeError("init rpc env error!") raise RuntimeError(f"Connect rpyc error: {e}")
return con.root return con.root
def start_rpyc_process(service: rpyc.Service, port: int): def start_rpyc_service(service: rpyc.Service, port: int):
# Return the proxy and the process t = ThreadedServer(
proc = multiprocessing.Process(target=init_rpyc_service, args=(service, port)) service=service,
port=port,
protocol_config={
"allow_public_attrs": True,
"allow_pickle": True,
"sync_request_timeout": 3600,
},
)
t.logger.setLevel(logging.WARN)
t.start()
def start_rpyc_service_process(service: rpyc.Service, port: int):
proc = multiprocessing.Process(target=start_rpyc_service, args=(service, port))
proc.start() proc.start()
proxy = connect_to_rpyc_service(port) return proc
assert proc.is_alive()
return proxy, proc
def suppress_other_loggers(): def suppress_other_loggers():
@@ -487,3 +484,66 @@ class APIKeyValidatorMiddleware(BaseHTTPMiddleware):
) )
response = await call_next(request) response = await call_next(request)
return response return response
def get_ip_address(ifname):
"""
Get the IP address of a network interface.
:param ifname: Name of the network interface (e.g., 'eth0')
:return: IP address of the network interface
"""
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
ip_address = fcntl.ioctl(
s.fileno(),
0x8915, # SIOCGIFADDR
struct.pack('256s', bytes(ifname[:15], 'utf-8'))
)[20:24]
return socket.inet_ntoa(ip_address)
def send_addrs_to_rank_0(model_port_args, server_args):
assert server_args.node_rank != 0 and server_args.dp_size == 1
import torch.distributed as dist
ifname = os.environ.get("SGLANG_SOCKET_IFNAME", os.environ.get("NCCL_SOCKET_IFNAME", "eth0"))
ip_addr = get_ip_address(ifname)
num_tp_ports = server_args.tp_size // server_args.nnodes
model_port_args.model_tp_ips[:num_tp_ports] = [ip_addr] * num_tp_ports
ip_addr = [int(x) for x in ip_addr.split(".")]
addrs_tensor = torch.tensor(ip_addr + model_port_args.model_tp_ports, dtype=torch.int)
init_method = f"tcp://{server_args.nccl_init_addr}"
dist.init_process_group(backend="gloo", init_method=init_method, rank=server_args.node_rank, world_size=server_args.nnodes)
dist.send(addrs_tensor, dst=0)
print(f"Node {server_args.node_rank} sent: ip_address {ip_addr} and ports {model_port_args.model_tp_ports}")
dist.barrier()
dist.destroy_process_group()
def receive_addrs(model_port_args, server_args):
assert server_args.node_rank == 0 and server_args.dp_size == 1
import torch.distributed as dist
ifname = os.environ.get("SGLANG_SOCKET_IFNAME", os.environ.get("NCCL_SOCKET_IFNAME", "eth0"))
ip_addr = get_ip_address(ifname)
num_tp_ports = server_args.tp_size // server_args.nnodes
model_port_args.model_tp_ips[:num_tp_ports] = [ip_addr] * num_tp_ports
init_method = f"tcp://{server_args.nccl_init_addr}"
dist.init_process_group(backend="gloo", init_method=init_method, rank=server_args.node_rank, world_size=server_args.nnodes)
for src_rank in range(1, server_args.nnodes):
tensor = torch.zeros(4 + num_tp_ports, dtype=torch.int)
dist.recv(tensor, src=src_rank)
ip = ".".join([str(x) for x in tensor[:4].tolist()])
ports = tensor[4:].tolist()
model_port_args.model_tp_ips[num_tp_ports * src_rank: num_tp_ports * (src_rank + 1)] = [ip] * num_tp_ports
model_port_args.model_tp_ports[num_tp_ports * src_rank: num_tp_ports * (src_rank + 1)] = ports
print(f"Node 0 received from rank {src_rank}: {tensor.tolist()}")
dist.barrier()
dist.destroy_process_group()
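
The gloo exchange above packs each node's IPv4 address as four ints followed by its rpyc ports into a single int tensor. The encoding round-trips as follows (standalone sketch with made-up port numbers, no process group needed):

```python
# Round-trip of the address encoding used by send_addrs_to_rank_0 /
# receive_addrs; the IP and ports are illustrative values.
import torch

ip, ports = "10.0.0.2", [30011, 30012]  # one remote node's address and rpyc ports
packed = torch.tensor([int(x) for x in ip.split(".")] + ports, dtype=torch.int)

# What receive_addrs does with the tensor after dist.recv():
decoded_ip = ".".join(str(x) for x in packed[:4].tolist())
decoded_ports = packed[4:].tolist()
assert (decoded_ip, decoded_ports) == (ip, ports)
```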