Use ipc instead of tcp in zmq (#1566)
This commit is contained in:
@@ -223,7 +223,6 @@ if __name__ == "__main__":
|
||||
model_path=args.model_path, # "liuhaotian/llava-v1.6-vicuna-7b",
|
||||
tokenizer_path=tokenizer_path,
|
||||
port=cur_port,
|
||||
additional_ports=[cur_port + 1, cur_port + 2, cur_port + 3, cur_port + 4],
|
||||
json_model_override_args=json.dumps(model_override_args),
|
||||
tp_size=1,
|
||||
)
|
||||
|
||||
@@ -66,9 +66,8 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.model_executor.model_runner import ModelRunner
|
||||
from sglang.srt.sampling.sampling_params import SamplingParams
|
||||
from sglang.srt.server import _set_envs_and_config
|
||||
from sglang.srt.server_args import ServerArgs
|
||||
from sglang.srt.server_args import PortArgs, ServerArgs
|
||||
from sglang.srt.utils import (
|
||||
allocate_init_ports,
|
||||
configure_logger,
|
||||
kill_child_process,
|
||||
suppress_other_loggers,
|
||||
@@ -127,11 +126,7 @@ def load_model(server_args, tp_rank):
|
||||
suppress_other_loggers()
|
||||
rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None
|
||||
|
||||
server_args.port, server_args.additional_ports = allocate_init_ports(
|
||||
server_args.port,
|
||||
server_args.additional_ports,
|
||||
server_args.dp_size,
|
||||
)
|
||||
port_args = PortArgs.init_new(server_args)
|
||||
model_config = ModelConfig(
|
||||
server_args.model_path,
|
||||
server_args.trust_remote_code,
|
||||
@@ -143,7 +138,7 @@ def load_model(server_args, tp_rank):
|
||||
gpu_id=tp_rank,
|
||||
tp_rank=tp_rank,
|
||||
tp_size=server_args.tp_size,
|
||||
nccl_port=server_args.additional_ports[-1],
|
||||
nccl_port=port_args.nccl_ports[0],
|
||||
server_args=server_args,
|
||||
)
|
||||
rank_print(f"max_total_num_tokens={model_runner.max_total_num_tokens}")
|
||||
|
||||
@@ -59,10 +59,10 @@ class DetokenizerManager:
|
||||
# Init inter-process communication
|
||||
context = zmq.Context(2)
|
||||
self.recv_from_scheduler = context.socket(zmq.PULL)
|
||||
self.recv_from_scheduler.bind(f"tcp://127.0.0.1:{port_args.detokenizer_port}")
|
||||
self.recv_from_scheduler.bind(f"ipc://{port_args.detokenizer_ipc_name}")
|
||||
|
||||
self.send_to_tokenizer = context.socket(zmq.PUSH)
|
||||
self.send_to_tokenizer.connect(f"tcp://127.0.0.1:{port_args.tokenizer_port}")
|
||||
self.send_to_tokenizer.connect(f"ipc://{port_args.tokenizer_ipc_name}")
|
||||
|
||||
if server_args.skip_tokenizer_init:
|
||||
self.tokenizer = None
|
||||
|
||||
@@ -96,14 +96,10 @@ class Scheduler:
|
||||
|
||||
if self.tp_rank == 0:
|
||||
self.recv_from_tokenizer = context.socket(zmq.PULL)
|
||||
self.recv_from_tokenizer.bind(
|
||||
f"tcp://127.0.0.1:{port_args.scheduler_input_port}"
|
||||
)
|
||||
self.recv_from_tokenizer.bind(f"ipc://{port_args.scheduler_input_ipc_name}")
|
||||
|
||||
self.send_to_detokenizer = context.socket(zmq.PUSH)
|
||||
self.send_to_detokenizer.connect(
|
||||
f"tcp://127.0.0.1:{port_args.detokenizer_port}"
|
||||
)
|
||||
self.send_to_detokenizer.connect(f"ipc://{port_args.detokenizer_ipc_name}")
|
||||
else:
|
||||
self.recv_from_tokenizer = self.send_to_detokenizer = None
|
||||
|
||||
|
||||
@@ -84,12 +84,10 @@ class TokenizerManager:
|
||||
# Init inter-process communication
|
||||
context = zmq.asyncio.Context(2)
|
||||
self.recv_from_detokenizer = context.socket(zmq.PULL)
|
||||
self.recv_from_detokenizer.bind(f"tcp://127.0.0.1:{port_args.tokenizer_port}")
|
||||
self.recv_from_detokenizer.bind(f"ipc://{port_args.tokenizer_ipc_name}")
|
||||
|
||||
self.send_to_scheduler = context.socket(zmq.PUSH)
|
||||
self.send_to_scheduler.connect(
|
||||
f"tcp://127.0.0.1:{port_args.scheduler_input_port}"
|
||||
)
|
||||
self.send_to_scheduler.connect(f"ipc://{port_args.scheduler_input_ipc_name}")
|
||||
|
||||
# Read model args
|
||||
self.model_path = server_args.model_path
|
||||
|
||||
@@ -16,7 +16,6 @@ limitations under the License.
|
||||
"""Memory pool."""
|
||||
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
@@ -62,9 +61,11 @@ class BaseTokenToKVPool:
|
||||
self,
|
||||
size: int,
|
||||
dtype: torch.dtype,
|
||||
device: str,
|
||||
):
|
||||
self.size = size
|
||||
self.dtype = dtype
|
||||
self.device = device
|
||||
if dtype == torch.float8_e5m2:
|
||||
# NOTE: Store as torch.uint8 because Tensor index_put is not implemented for torch.float8_e5m2
|
||||
self.store_dtype = torch.uint8
|
||||
@@ -84,7 +85,7 @@ class BaseTokenToKVPool:
|
||||
select_index = self.free_slots[:need_size]
|
||||
self.free_slots = self.free_slots[need_size:]
|
||||
|
||||
return torch.tensor(select_index, dtype=torch.int32, device="cuda")
|
||||
return torch.tensor(select_index, dtype=torch.int32, device=self.device)
|
||||
|
||||
def free(self, free_index: torch.Tensor):
|
||||
self.free_slots = np.concatenate((self.free_slots, free_index.cpu().numpy()))
|
||||
@@ -123,7 +124,7 @@ class MHATokenToKVPool(BaseTokenToKVPool):
|
||||
layer_num: int,
|
||||
device: str,
|
||||
):
|
||||
super().__init__(size, dtype)
|
||||
super().__init__(size, dtype, device)
|
||||
|
||||
# [size, head_num, head_dim] for each layer
|
||||
# The padded slot 0 is used for writing dummy outputs from padded tokens.
|
||||
@@ -187,7 +188,7 @@ class MLATokenToKVPool(BaseTokenToKVPool):
|
||||
layer_num: int,
|
||||
device: str,
|
||||
):
|
||||
super().__init__(size, dtype)
|
||||
super().__init__(size, dtype, device)
|
||||
|
||||
self.kv_lora_rank = kv_lora_rank
|
||||
# The padded slot 0 is used for writing dummy outputs from padded tokens.
|
||||
|
||||
@@ -24,6 +24,7 @@ import json
|
||||
import logging
|
||||
import multiprocessing as mp
|
||||
import os
|
||||
import random
|
||||
import threading
|
||||
import time
|
||||
from http import HTTPStatus
|
||||
@@ -68,9 +69,9 @@ from sglang.srt.openai_api.protocol import ModelCard, ModelList
|
||||
from sglang.srt.server_args import PortArgs, ServerArgs
|
||||
from sglang.srt.utils import (
|
||||
add_api_key_middleware,
|
||||
allocate_init_ports,
|
||||
assert_pkg_version,
|
||||
configure_logger,
|
||||
is_port_available,
|
||||
kill_child_process,
|
||||
maybe_set_triton_cache_manager,
|
||||
prepare_model_and_tokenizer,
|
||||
@@ -302,18 +303,7 @@ def launch_server(
|
||||
_set_envs_and_config(server_args)
|
||||
|
||||
# Allocate ports for inter-process communications
|
||||
server_args.port, server_args.additional_ports = allocate_init_ports(
|
||||
server_args.port,
|
||||
server_args.additional_ports,
|
||||
server_args.dp_size,
|
||||
)
|
||||
ports = server_args.additional_ports
|
||||
port_args = PortArgs(
|
||||
tokenizer_port=ports[0],
|
||||
scheduler_input_port=ports[1],
|
||||
detokenizer_port=ports[2],
|
||||
nccl_ports=ports[3:],
|
||||
)
|
||||
port_args = PortArgs.init_new(server_args)
|
||||
logger.info(f"{server_args=}")
|
||||
|
||||
# If using model from www.modelscope.cn, first download the model.
|
||||
@@ -499,17 +489,16 @@ class Runtime:
|
||||
self.server_args = ServerArgs(*args, log_level=log_level, **kwargs)
|
||||
|
||||
# Pre-allocate ports
|
||||
self.server_args.port, self.server_args.additional_ports = allocate_init_ports(
|
||||
self.server_args.port,
|
||||
self.server_args.additional_ports,
|
||||
self.server_args.dp_size,
|
||||
)
|
||||
for port in range(10000, 40000):
|
||||
if is_port_available(port):
|
||||
break
|
||||
port += 1
|
||||
self.server_args.port = port
|
||||
|
||||
self.url = self.server_args.url()
|
||||
self.generate_url = (
|
||||
f"http://{self.server_args.host}:{self.server_args.port}/generate"
|
||||
)
|
||||
self.generate_url = self.url + "/generate"
|
||||
|
||||
# NOTE: We store pid instead of proc to fix some issues during __delete__
|
||||
self.pid = None
|
||||
pipe_reader, pipe_writer = mp.Pipe(duplex=False)
|
||||
|
||||
|
||||
@@ -19,9 +19,10 @@ import argparse
|
||||
import dataclasses
|
||||
import logging
|
||||
import random
|
||||
from typing import List, Optional, Union
|
||||
import tempfile
|
||||
from typing import List, Optional
|
||||
|
||||
from sglang.srt.utils import is_hip, is_ipv6
|
||||
from sglang.srt.utils import is_hip, is_ipv6, is_port_available
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -46,7 +47,6 @@ class ServerArgs:
|
||||
# Port
|
||||
host: str = "127.0.0.1"
|
||||
port: int = 30000
|
||||
additional_ports: Optional[Union[List[int], int]] = None
|
||||
|
||||
# Memory and scheduling
|
||||
mem_fraction_static: Optional[float] = None
|
||||
@@ -134,11 +134,6 @@ class ServerArgs:
|
||||
else:
|
||||
self.mem_fraction_static = 0.88
|
||||
|
||||
if isinstance(self.additional_ports, int):
|
||||
self.additional_ports = [self.additional_ports]
|
||||
elif self.additional_ports is None:
|
||||
self.additional_ports = []
|
||||
|
||||
if self.random_seed is None:
|
||||
self.random_seed = random.randint(0, 1 << 30)
|
||||
|
||||
@@ -199,13 +194,6 @@ class ServerArgs:
|
||||
parser.add_argument(
|
||||
"--port", type=int, default=ServerArgs.port, help="The port of the server."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--additional-ports",
|
||||
type=int,
|
||||
nargs="*",
|
||||
default=[],
|
||||
help="The additional ports specified for the server.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tokenizer-mode",
|
||||
type=str,
|
||||
@@ -625,16 +613,31 @@ def prepare_server_args(argv: List[str]) -> ServerArgs:
|
||||
|
||||
@dataclasses.dataclass
class PortArgs:
    """Inter-process communication endpoints for one server instance.

    The tokenizer, scheduler and detokenizer exchange messages over zmq
    ``ipc://`` endpoints identified by filesystem paths; only the nccl
    rendezvous still needs a TCP port.
    """

    # The ipc filename for tokenizer to receive inputs from detokenizer (zmq)
    tokenizer_ipc_name: str
    # The ipc filename for scheduler (rank 0) to receive inputs from tokenizer (zmq)
    scheduler_input_ipc_name: str
    # The ipc filename for detokenizer to receive inputs from scheduler (zmq)
    detokenizer_ipc_name: str

    # The port for nccl initialization for multiple TP groups (torch.dist)
    nccl_ports: List[int]

    @staticmethod
    def _new_ipc_name() -> str:
        """Reserve a unique filesystem path to be used as a zmq ipc endpoint."""
        # delete=False keeps the file on disk after the handle is closed, so
        # the path stays reserved; closing explicitly avoids relying on
        # garbage collection to release the file descriptor.
        with tempfile.NamedTemporaryFile(delete=False) as handle:
            return handle.name

    @classmethod
    def init_new(cls, server_args) -> "PortArgs":
        """Build a fresh set of endpoints for ``server_args``.

        Args:
            server_args: Object exposing an integer ``port`` attribute (the
                HTTP port). The nccl port search starts right after it.

        Returns:
            A new instance with three unique ipc file names and a single
            free nccl port.

        Raises:
            RuntimeError: If no free TCP port exists above ``server_args.port``.
        """
        max_port = 65535
        port = server_args.port + 1
        # Scan upwards for the first free port, but never past the valid
        # TCP range (the original loop was unbounded).
        while port <= max_port and not is_port_available(port):
            port += 1
        if port > max_port:
            raise RuntimeError(
                f"No available port found above {server_args.port} for nccl."
            )

        return cls(
            tokenizer_ipc_name=cls._new_ipc_name(),
            scheduler_input_ipc_name=cls._new_ipc_name(),
            detokenizer_ipc_name=cls._new_ipc_name(),
            nccl_ports=[port],
        )
|
||||
|
||||
|
||||
class LoRAPathAction(argparse.Action):
|
||||
def __call__(self, parser, namespace, values, option_string=None):
|
||||
|
||||
@@ -177,35 +177,6 @@ def is_port_available(port):
|
||||
return False
|
||||
|
||||
|
||||
def allocate_init_ports(
    port: Optional[int] = None,
    additional_ports: Optional[List[int]] = None,
    dp_size: int = 1,
):
    """Allocate ports for all connections.

    Args:
        port: Preferred port for the HTTP server, or ``None`` for no
            preference.
        additional_ports: Preferred ports for the remaining connections, in
            priority order. May be ``None`` or empty.
        dp_size: Data-parallel size; one nccl port is reserved per replica.

    Returns:
        A tuple ``(http_port, other_ports)`` where ``other_ports`` holds the
        remaining ``3 + dp_size`` ports.
    """
    requested = [] if port is None else [port]
    requested.extend(additional_ports or [])

    # De-duplicate while preserving the caller's order. The previous
    # ``list(set(...))`` had arbitrary ordering, so the first entry was not
    # guaranteed to be the requested HTTP port even when it was free.
    ret_ports = [p for p in dict.fromkeys(requested) if is_port_available(p)]

    # Continue the search right after the last usable requested port, or from
    # 10000 when none of the requested ports can be used.
    cur_port = ret_ports[-1] + 1 if ret_ports else 10000

    # HTTP + Tokenizer + Controller + Detokenizer + dp_size * 1 (nccl)
    num_ports_needed = 4 + dp_size
    while len(ret_ports) < num_ports_needed:
        if cur_port not in ret_ports and is_port_available(cur_port):
            ret_ports.append(cur_port)
        cur_port += 1

    if port is not None and ret_ports[0] != port:
        logger.warning(
            f"WARNING: Port {port} is not available. Use port {ret_ports[0]} instead."
        )

    return ret_ports[0], ret_ports[1:num_ports_needed]
|
||||
|
||||
|
||||
def is_multimodal_model(model_architectures):
|
||||
if (
|
||||
"LlavaLlamaForCausalLM" in model_architectures
|
||||
|
||||
Reference in New Issue
Block a user