Use ipc instead of tcp in zmq (#1566)

This commit is contained in:
Lianmin Zheng
2024-10-04 00:45:52 -07:00
committed by GitHub
parent 32eb6e96f2
commit 114bbc8651
9 changed files with 48 additions and 96 deletions

View File

@@ -223,7 +223,6 @@ if __name__ == "__main__":
model_path=args.model_path, # "liuhaotian/llava-v1.6-vicuna-7b", model_path=args.model_path, # "liuhaotian/llava-v1.6-vicuna-7b",
tokenizer_path=tokenizer_path, tokenizer_path=tokenizer_path,
port=cur_port, port=cur_port,
additional_ports=[cur_port + 1, cur_port + 2, cur_port + 3, cur_port + 4],
json_model_override_args=json.dumps(model_override_args), json_model_override_args=json.dumps(model_override_args),
tp_size=1, tp_size=1,
) )

View File

@@ -66,9 +66,8 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_executor.model_runner import ModelRunner from sglang.srt.model_executor.model_runner import ModelRunner
from sglang.srt.sampling.sampling_params import SamplingParams from sglang.srt.sampling.sampling_params import SamplingParams
from sglang.srt.server import _set_envs_and_config from sglang.srt.server import _set_envs_and_config
from sglang.srt.server_args import ServerArgs from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import ( from sglang.srt.utils import (
allocate_init_ports,
configure_logger, configure_logger,
kill_child_process, kill_child_process,
suppress_other_loggers, suppress_other_loggers,
@@ -127,11 +126,7 @@ def load_model(server_args, tp_rank):
suppress_other_loggers() suppress_other_loggers()
rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None
server_args.port, server_args.additional_ports = allocate_init_ports( port_args = PortArgs.init_new(server_args)
server_args.port,
server_args.additional_ports,
server_args.dp_size,
)
model_config = ModelConfig( model_config = ModelConfig(
server_args.model_path, server_args.model_path,
server_args.trust_remote_code, server_args.trust_remote_code,
@@ -143,7 +138,7 @@ def load_model(server_args, tp_rank):
gpu_id=tp_rank, gpu_id=tp_rank,
tp_rank=tp_rank, tp_rank=tp_rank,
tp_size=server_args.tp_size, tp_size=server_args.tp_size,
nccl_port=server_args.additional_ports[-1], nccl_port=port_args.nccl_ports[0],
server_args=server_args, server_args=server_args,
) )
rank_print(f"max_total_num_tokens={model_runner.max_total_num_tokens}") rank_print(f"max_total_num_tokens={model_runner.max_total_num_tokens}")

View File

@@ -59,10 +59,10 @@ class DetokenizerManager:
# Init inter-process communication # Init inter-process communication
context = zmq.Context(2) context = zmq.Context(2)
self.recv_from_scheduler = context.socket(zmq.PULL) self.recv_from_scheduler = context.socket(zmq.PULL)
self.recv_from_scheduler.bind(f"tcp://127.0.0.1:{port_args.detokenizer_port}") self.recv_from_scheduler.bind(f"ipc://{port_args.detokenizer_ipc_name}")
self.send_to_tokenizer = context.socket(zmq.PUSH) self.send_to_tokenizer = context.socket(zmq.PUSH)
self.send_to_tokenizer.connect(f"tcp://127.0.0.1:{port_args.tokenizer_port}") self.send_to_tokenizer.connect(f"ipc://{port_args.tokenizer_ipc_name}")
if server_args.skip_tokenizer_init: if server_args.skip_tokenizer_init:
self.tokenizer = None self.tokenizer = None

View File

@@ -96,14 +96,10 @@ class Scheduler:
if self.tp_rank == 0: if self.tp_rank == 0:
self.recv_from_tokenizer = context.socket(zmq.PULL) self.recv_from_tokenizer = context.socket(zmq.PULL)
self.recv_from_tokenizer.bind( self.recv_from_tokenizer.bind(f"ipc://{port_args.scheduler_input_ipc_name}")
f"tcp://127.0.0.1:{port_args.scheduler_input_port}"
)
self.send_to_detokenizer = context.socket(zmq.PUSH) self.send_to_detokenizer = context.socket(zmq.PUSH)
self.send_to_detokenizer.connect( self.send_to_detokenizer.connect(f"ipc://{port_args.detokenizer_ipc_name}")
f"tcp://127.0.0.1:{port_args.detokenizer_port}"
)
else: else:
self.recv_from_tokenizer = self.send_to_detokenizer = None self.recv_from_tokenizer = self.send_to_detokenizer = None

View File

@@ -84,12 +84,10 @@ class TokenizerManager:
# Init inter-process communication # Init inter-process communication
context = zmq.asyncio.Context(2) context = zmq.asyncio.Context(2)
self.recv_from_detokenizer = context.socket(zmq.PULL) self.recv_from_detokenizer = context.socket(zmq.PULL)
self.recv_from_detokenizer.bind(f"tcp://127.0.0.1:{port_args.tokenizer_port}") self.recv_from_detokenizer.bind(f"ipc://{port_args.tokenizer_ipc_name}")
self.send_to_scheduler = context.socket(zmq.PUSH) self.send_to_scheduler = context.socket(zmq.PUSH)
self.send_to_scheduler.connect( self.send_to_scheduler.connect(f"ipc://{port_args.scheduler_input_ipc_name}")
f"tcp://127.0.0.1:{port_args.scheduler_input_port}"
)
# Read model args # Read model args
self.model_path = server_args.model_path self.model_path = server_args.model_path

View File

@@ -16,7 +16,6 @@ limitations under the License.
"""Memory pool.""" """Memory pool."""
import logging import logging
from abc import ABC, abstractmethod
from typing import List, Tuple, Union from typing import List, Tuple, Union
import numpy as np import numpy as np
@@ -62,9 +61,11 @@ class BaseTokenToKVPool:
self, self,
size: int, size: int,
dtype: torch.dtype, dtype: torch.dtype,
device: str,
): ):
self.size = size self.size = size
self.dtype = dtype self.dtype = dtype
self.device = device
if dtype == torch.float8_e5m2: if dtype == torch.float8_e5m2:
# NOTE: Store as torch.uint8 because Tensor index_put is not implemented for torch.float8_e5m2 # NOTE: Store as torch.uint8 because Tensor index_put is not implemented for torch.float8_e5m2
self.store_dtype = torch.uint8 self.store_dtype = torch.uint8
@@ -84,7 +85,7 @@ class BaseTokenToKVPool:
select_index = self.free_slots[:need_size] select_index = self.free_slots[:need_size]
self.free_slots = self.free_slots[need_size:] self.free_slots = self.free_slots[need_size:]
return torch.tensor(select_index, dtype=torch.int32, device="cuda") return torch.tensor(select_index, dtype=torch.int32, device=self.device)
def free(self, free_index: torch.Tensor): def free(self, free_index: torch.Tensor):
self.free_slots = np.concatenate((self.free_slots, free_index.cpu().numpy())) self.free_slots = np.concatenate((self.free_slots, free_index.cpu().numpy()))
@@ -123,7 +124,7 @@ class MHATokenToKVPool(BaseTokenToKVPool):
layer_num: int, layer_num: int,
device: str, device: str,
): ):
super().__init__(size, dtype) super().__init__(size, dtype, device)
# [size, head_num, head_dim] for each layer # [size, head_num, head_dim] for each layer
# The padded slot 0 is used for writing dummy outputs from padded tokens. # The padded slot 0 is used for writing dummy outputs from padded tokens.
@@ -187,7 +188,7 @@ class MLATokenToKVPool(BaseTokenToKVPool):
layer_num: int, layer_num: int,
device: str, device: str,
): ):
super().__init__(size, dtype) super().__init__(size, dtype, device)
self.kv_lora_rank = kv_lora_rank self.kv_lora_rank = kv_lora_rank
# The padded slot 0 is used for writing dummy outputs from padded tokens. # The padded slot 0 is used for writing dummy outputs from padded tokens.

View File

@@ -24,6 +24,7 @@ import json
import logging import logging
import multiprocessing as mp import multiprocessing as mp
import os import os
import random
import threading import threading
import time import time
from http import HTTPStatus from http import HTTPStatus
@@ -68,9 +69,9 @@ from sglang.srt.openai_api.protocol import ModelCard, ModelList
from sglang.srt.server_args import PortArgs, ServerArgs from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import ( from sglang.srt.utils import (
add_api_key_middleware, add_api_key_middleware,
allocate_init_ports,
assert_pkg_version, assert_pkg_version,
configure_logger, configure_logger,
is_port_available,
kill_child_process, kill_child_process,
maybe_set_triton_cache_manager, maybe_set_triton_cache_manager,
prepare_model_and_tokenizer, prepare_model_and_tokenizer,
@@ -302,18 +303,7 @@ def launch_server(
_set_envs_and_config(server_args) _set_envs_and_config(server_args)
# Allocate ports for inter-process communications # Allocate ports for inter-process communications
server_args.port, server_args.additional_ports = allocate_init_ports( port_args = PortArgs.init_new(server_args)
server_args.port,
server_args.additional_ports,
server_args.dp_size,
)
ports = server_args.additional_ports
port_args = PortArgs(
tokenizer_port=ports[0],
scheduler_input_port=ports[1],
detokenizer_port=ports[2],
nccl_ports=ports[3:],
)
logger.info(f"{server_args=}") logger.info(f"{server_args=}")
# If using model from www.modelscope.cn, first download the model. # If using model from www.modelscope.cn, first download the model.
@@ -499,17 +489,16 @@ class Runtime:
self.server_args = ServerArgs(*args, log_level=log_level, **kwargs) self.server_args = ServerArgs(*args, log_level=log_level, **kwargs)
# Pre-allocate ports # Pre-allocate ports
self.server_args.port, self.server_args.additional_ports = allocate_init_ports( for port in range(10000, 40000):
self.server_args.port, if is_port_available(port):
self.server_args.additional_ports, break
self.server_args.dp_size, port += 1
) self.server_args.port = port
self.url = self.server_args.url() self.url = self.server_args.url()
self.generate_url = ( self.generate_url = self.url + "/generate"
f"http://{self.server_args.host}:{self.server_args.port}/generate"
)
# NOTE: We store pid instead of proc to fix some issues during __delete__
self.pid = None self.pid = None
pipe_reader, pipe_writer = mp.Pipe(duplex=False) pipe_reader, pipe_writer = mp.Pipe(duplex=False)

View File

@@ -19,9 +19,10 @@ import argparse
import dataclasses import dataclasses
import logging import logging
import random import random
from typing import List, Optional, Union import tempfile
from typing import List, Optional
from sglang.srt.utils import is_hip, is_ipv6 from sglang.srt.utils import is_hip, is_ipv6, is_port_available
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -46,7 +47,6 @@ class ServerArgs:
# Port # Port
host: str = "127.0.0.1" host: str = "127.0.0.1"
port: int = 30000 port: int = 30000
additional_ports: Optional[Union[List[int], int]] = None
# Memory and scheduling # Memory and scheduling
mem_fraction_static: Optional[float] = None mem_fraction_static: Optional[float] = None
@@ -134,11 +134,6 @@ class ServerArgs:
else: else:
self.mem_fraction_static = 0.88 self.mem_fraction_static = 0.88
if isinstance(self.additional_ports, int):
self.additional_ports = [self.additional_ports]
elif self.additional_ports is None:
self.additional_ports = []
if self.random_seed is None: if self.random_seed is None:
self.random_seed = random.randint(0, 1 << 30) self.random_seed = random.randint(0, 1 << 30)
@@ -199,13 +194,6 @@ class ServerArgs:
parser.add_argument( parser.add_argument(
"--port", type=int, default=ServerArgs.port, help="The port of the server." "--port", type=int, default=ServerArgs.port, help="The port of the server."
) )
parser.add_argument(
"--additional-ports",
type=int,
nargs="*",
default=[],
help="The additional ports specified for the server.",
)
parser.add_argument( parser.add_argument(
"--tokenizer-mode", "--tokenizer-mode",
type=str, type=str,
@@ -625,16 +613,31 @@ def prepare_server_args(argv: List[str]) -> ServerArgs:
@dataclasses.dataclass @dataclasses.dataclass
class PortArgs: class PortArgs:
# The port for tokenizer to receive inputs from detokenizer (zmq) # The ipc filename for tokenizer to receive inputs from detokenizer (zmq)
tokenizer_port: int tokenizer_ipc_name: str
# The port for scheduler (rank 0) to receive inputs from tokenizer (zmq) # The ipc filename for scheduler (rank 0) to receive inputs from tokenizer (zmq)
scheduler_input_port: int scheduler_input_ipc_name: str
# The port for detokenizer to receive inputs from scheduler (zmq) # The ipc filename for detokenizer to receive inputs from scheduler (zmq)
detokenizer_port: int detokenizer_ipc_name: str
# The port for nccl initialization for multiple TP groups (torch.dist) # The port for nccl initialization for multiple TP groups (torch.dist)
nccl_ports: List[int] nccl_ports: List[int]
@classmethod
def init_new(cls, server_args):
    """Create a new ``PortArgs`` for inter-process communication.

    The zmq links between tokenizer, scheduler, and detokenizer use
    ``ipc://`` sockets backed by unique temp-file paths, so only the NCCL
    initialization still needs a real TCP port.

    Args:
        server_args: object exposing a ``port`` attribute (the HTTP server
            port); the NCCL port search starts just above it.

    Returns:
        A ``PortArgs`` with three unique ipc filenames and one free TCP port.
    """
    # Probe upward from the HTTP port for a free TCP port for NCCL init.
    # NOTE(review): inherently racy — the port can be taken between this
    # check and the actual bind; acceptable as a best-effort pick.
    nccl_port = server_args.port + 1
    while not is_port_available(nccl_port):
        nccl_port += 1

    # NamedTemporaryFile(delete=False) is used only to reserve a unique
    # filesystem path; zmq later binds ipc:// sockets over these paths.
    return cls(
        tokenizer_ipc_name=tempfile.NamedTemporaryFile(delete=False).name,
        scheduler_input_ipc_name=tempfile.NamedTemporaryFile(delete=False).name,
        detokenizer_ipc_name=tempfile.NamedTemporaryFile(delete=False).name,
        nccl_ports=[nccl_port],
    )
class LoRAPathAction(argparse.Action): class LoRAPathAction(argparse.Action):
def __call__(self, parser, namespace, values, option_string=None): def __call__(self, parser, namespace, values, option_string=None):

View File

@@ -177,35 +177,6 @@ def is_port_available(port):
return False return False
def allocate_init_ports(
port: Optional[int] = None,
additional_ports: Optional[List[int]] = None,
dp_size: int = 1,
):
"""Allocate ports for all connections."""
if additional_ports:
ret_ports = [port] + additional_ports
else:
ret_ports = [port]
ret_ports = list(set(x for x in ret_ports if is_port_available(x)))
cur_port = ret_ports[-1] + 1 if len(ret_ports) > 0 else 10000
# HTTP + Tokenizer + Controller + Detokenizer + dp_size * 1 (nccl)
num_ports_needed = 4 + dp_size
while len(ret_ports) < num_ports_needed:
if cur_port not in ret_ports and is_port_available(cur_port):
ret_ports.append(cur_port)
cur_port += 1
if port is not None and ret_ports[0] != port:
logger.warning(
f"WARNING: Port {port} is not available. Use port {ret_ports[0]} instead."
)
return ret_ports[0], ret_ports[1:num_ports_needed]
def is_multimodal_model(model_architectures): def is_multimodal_model(model_architectures):
if ( if (
"LlavaLlamaForCausalLM" in model_architectures "LlavaLlamaForCausalLM" in model_architectures