Improve tensor parallel performance (#625)
Co-authored-by: Mingyi <wisclmy0611@gmail.com>
This commit is contained in:
@@ -377,6 +377,14 @@ python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port
|
|||||||
python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000 --mem-fraction-static 0.7
|
python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000 --mem-fraction-static 0.7
|
||||||
```
|
```
|
||||||
- See [hyperparameter_tuning.md](docs/hyperparameter_tuning.md) on tuning hyperparameters for better performance.
|
- See [hyperparameter_tuning.md](docs/hyperparameter_tuning.md) on tuning hyperparameters for better performance.
|
||||||
|
- Add `--nnodes 2` to run tensor parallelism on multiple nodes. If you have two nodes with two GPUs on each node and want to run TP=4, let `sgl-dev-1` be the hostname of the first node and `50000` be an available port.
|
||||||
|
```
|
||||||
|
# Node 0
|
||||||
|
python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --tp 4 --nccl-init sgl-dev-1:50000 --nnodes 2 --node-rank 0
|
||||||
|
|
||||||
|
# Node 1
|
||||||
|
python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --tp 4 --nccl-init sgl-dev-1:50000 --nnodes 2 --node-rank 1
|
||||||
|
```
|
||||||
|
|
||||||
### Supported Models
|
### Supported Models
|
||||||
- Llama
|
- Llama
|
||||||
|
|||||||
@@ -96,8 +96,11 @@ def run_one_batch_size(bs):
|
|||||||
ret = response.json()
|
ret = response.json()
|
||||||
print(ret)
|
print(ret)
|
||||||
|
|
||||||
|
input_len = args.input_len if args.input_len else 1
|
||||||
|
output_len = max_new_tokens
|
||||||
|
|
||||||
output_throughput = bs * max_new_tokens / latency
|
output_throughput = bs * max_new_tokens / latency
|
||||||
overall_throughput = bs * (args.input_len + max_new_tokens) / latency
|
overall_throughput = bs * (input_len + output_len) / latency
|
||||||
print(f"latency: {latency:.2f} s")
|
print(f"latency: {latency:.2f} s")
|
||||||
print(f"decode throughput: {output_throughput:.2f} token/s")
|
print(f"decode throughput: {output_throughput:.2f} token/s")
|
||||||
print(f"overall throughput: {overall_throughput:.2f} token/s")
|
print(f"overall throughput: {overall_throughput:.2f} token/s")
|
||||||
|
|||||||
@@ -312,6 +312,9 @@ def main(args: argparse.Namespace):
|
|||||||
np.sum([output_len for _, output_len, _ in REQUEST_LATENCY]) / benchmark_time
|
np.sum([output_len for _, output_len, _ in REQUEST_LATENCY]) / benchmark_time
|
||||||
)
|
)
|
||||||
|
|
||||||
|
#latencies = [round(latency, 2) for _, _, latency in REQUEST_LATENCY]
|
||||||
|
#print(latencies)
|
||||||
|
|
||||||
print(f"Total time: {benchmark_time:.2f} s")
|
print(f"Total time: {benchmark_time:.2f} s")
|
||||||
print(f"Request throughput: {args.num_prompts / benchmark_time:.2f} requests/s")
|
print(f"Request throughput: {args.num_prompts / benchmark_time:.2f} requests/s")
|
||||||
print(f"Decoding throughput: {decoding_throughput:.2f} token/s")
|
print(f"Decoding throughput: {decoding_throughput:.2f} token/s")
|
||||||
|
|||||||
@@ -2,11 +2,10 @@
|
|||||||
|
|
||||||
- `backend`: Various backends for the language interpreter.
|
- `backend`: Various backends for the language interpreter.
|
||||||
- `lang`: The frontend language.
|
- `lang`: The frontend language.
|
||||||
- `srt`: The runtime for running local models.
|
- `srt`: The serving engine for running local models (SRT = SGLang Runtime).
|
||||||
- `test`: Test utilities.
|
- `test`: Test utilities.
|
||||||
- `api.py`: Public API.
|
- `api.py`: Public API.
|
||||||
- `bench_latency.py`: Benchmark utilities.
|
- `bench_latency.py`: Benchmark utilities.
|
||||||
- `global_config.py`: The global configs and constants.
|
- `global_config.py`: The global configs and constants.
|
||||||
- `launch_server.py`: The entry point of launching local server.
|
- `launch_server.py`: The entry point of launching local server.
|
||||||
- `utils.py`: Common utilities.
|
- `utils.py`: Common utilities.
|
||||||
|
|
||||||
|
|||||||
@@ -42,6 +42,8 @@ class LoadBalanceMethod(Enum):
|
|||||||
|
|
||||||
|
|
||||||
class Controller:
|
class Controller:
|
||||||
|
"""A controller that manages multiple data parallel workers."""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
load_balance_method: str,
|
load_balance_method: str,
|
||||||
@@ -183,9 +185,11 @@ def start_controller_process(
|
|||||||
except Exception:
|
except Exception:
|
||||||
pipe_writer.send(get_exception_traceback())
|
pipe_writer.send(get_exception_traceback())
|
||||||
raise
|
raise
|
||||||
|
|
||||||
pipe_writer.send("init ok")
|
pipe_writer.send("init ok")
|
||||||
loop = asyncio.get_event_loop()
|
|
||||||
|
loop = asyncio.new_event_loop()
|
||||||
|
loop.set_default_executor(ThreadPoolExecutor(max_workers=256))
|
||||||
|
|
||||||
asyncio.set_event_loop(loop)
|
asyncio.set_event_loop(loop)
|
||||||
loop.create_task(controller.loop_for_recv_requests())
|
loop.create_task(controller.loop_for_recv_requests())
|
||||||
loop.run_until_complete(controller.loop_for_forward())
|
loop.run_until_complete(controller.loop_for_forward())
|
||||||
|
|||||||
@@ -1,28 +1,104 @@
|
|||||||
"""A controller that manages a group of tensor parallel workers."""
|
"""A controller that manages a group of tensor parallel workers."""
|
||||||
|
|
||||||
import asyncio
|
import multiprocessing
|
||||||
import logging
|
import logging
|
||||||
from concurrent.futures import ThreadPoolExecutor
|
import os
|
||||||
|
import pickle
|
||||||
|
|
||||||
import uvloop
|
import torch
|
||||||
|
import torch.distributed as dist
|
||||||
import zmq
|
import zmq
|
||||||
import zmq.asyncio
|
import zmq.asyncio
|
||||||
|
|
||||||
from sglang.global_config import global_config
|
from sglang.srt.managers.controller.tp_worker import ModelTpServer
|
||||||
from sglang.srt.managers.controller.tp_worker import ModelTpClient
|
from sglang.srt.server_args import PortArgs, ServerArgs, ModelPortArgs
|
||||||
from sglang.srt.server_args import PortArgs, ServerArgs
|
|
||||||
from sglang.srt.utils import kill_parent_process
|
from sglang.srt.utils import kill_parent_process
|
||||||
from sglang.utils import get_exception_traceback
|
from sglang.utils import get_exception_traceback
|
||||||
|
|
||||||
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
|
|
||||||
|
|
||||||
logger = logging.getLogger("srt.controller")
|
logger = logging.getLogger("srt.controller")
|
||||||
|
|
||||||
|
|
||||||
|
def run_tp_server(
    gpu_id: int,
    tp_rank: int,
    server_args: ServerArgs,
    model_port_args: ModelPortArgs,
    model_overide_args: dict,
):
    """Entry point of one tensor parallel worker process.

    Builds a ``ModelTpServer`` for ``tp_rank`` on ``gpu_id`` and then serves
    forever: every iteration obtains the request list through
    ``broadcast_recv_input`` (this worker contributes no payload of its own,
    hence the ``None``) and hands it to ``exposed_step``.

    Any exception is logged together with its traceback and re-raised.
    """
    try:
        worker = ModelTpServer(
            gpu_id, tp_rank, server_args, model_port_args, model_overide_args
        )
        # Group handle used for the input broadcast below.
        cpu_group = worker.model_runner.tp_group.cpu_group

        while True:
            step_inputs = broadcast_recv_input(None, tp_rank, cpu_group)
            worker.exposed_step(step_inputs)
    except Exception:
        logger.error("Exception in run_tp_server:\n" + get_exception_traceback())
        raise
|
||||||
|
|
||||||
|
|
||||||
|
def launch_tp_servers(gpu_ids, tp_rank_range, server_args,
                      model_port_args, model_overide_args):
    """Spawn one ``run_tp_server`` process per rank in ``tp_rank_range``.

    ``gpu_ids`` is indexed by tp rank to select the device passed to each
    worker. Every process is started before the next one is created.

    Returns the list of started ``multiprocessing.Process`` objects, in the
    iteration order of ``tp_rank_range``.
    """
    started = []
    for tp_rank in tp_rank_range:
        worker_args = (
            gpu_ids[tp_rank],
            tp_rank,
            server_args,
            model_port_args,
            model_overide_args,
        )
        worker = multiprocessing.Process(target=run_tp_server, args=worker_args)
        worker.start()
        started.append(worker)

    return started
|
||||||
|
|
||||||
|
|
||||||
|
def broadcast_recv_input(data, rank, dist_group):
    """Broadcast inputs from rank=0 to all other ranks with torch.dist backend.

    The payload is pickled on rank 0 and sent as up to two collectives:
    first a one-element int64 tensor holding the byte length, then (only when
    that length is non-zero) a uint8 tensor holding the pickled bytes.

    Args:
        data: The object to send. Only read on rank 0; other ranks may pass
            ``None``. An empty (or ``None``) payload is announced as length 0
            and no second broadcast is issued.
        rank: Rank of the calling process; 0 sends, everything else receives.
        dist_group: The ``torch.distributed`` process group to broadcast over.

    Returns:
        On non-zero ranks, the unpickled object sent by rank 0, or ``[]`` when
        rank 0 announced an empty payload. On rank 0, ``None``.
    """
    if rank == 0:
        if data is None or len(data) == 0:
            payload = bytearray()
        else:
            # bytearray gives torch.frombuffer a writable buffer, so the bytes
            # are wrapped without materialising one Python int per byte.
            payload = bytearray(pickle.dumps(data))

        tensor_size = torch.tensor([len(payload)], dtype=torch.long)
        dist.broadcast(tensor_size, src=0, group=dist_group)

        if payload:
            # The tensor shares memory with `payload`, which stays alive for
            # the duration of the collective.
            tensor_data = torch.frombuffer(payload, dtype=torch.uint8)
            dist.broadcast(tensor_data, src=0, group=dist_group)
        return None

    # Receiver side: learn the payload length first.
    tensor_size = torch.tensor([0], dtype=torch.long)
    dist.broadcast(tensor_size, src=0, group=dist_group)
    size = tensor_size.item()

    if size == 0:
        return []

    tensor_data = torch.empty(size, dtype=torch.uint8)
    dist.broadcast(tensor_data, src=0, group=dist_group)

    # NOTE(security): pickle.loads executes arbitrary code from its input.
    # This presumes the bytes come from the rank-0 process of the same job --
    # verify before exposing the process group to untrusted peers.
    serialized_data = tensor_data.numpy().tobytes()
    return pickle.loads(serialized_data)
|
||||||
|
|
||||||
|
|
||||||
class ControllerSingle:
|
class ControllerSingle:
|
||||||
def __init__(self, model_client: ModelTpClient, port_args: PortArgs):
|
"""A controller that manages a group of tensor parallel workers."""
|
||||||
|
|
||||||
|
def __init__(self, server_args: ServerArgs, port_args: PortArgs, model_overide_args: dict):
|
||||||
|
# Parse args
|
||||||
|
self.server_args = server_args
|
||||||
|
|
||||||
# Init communication
|
# Init communication
|
||||||
context = zmq.asyncio.Context(2)
|
context = zmq.Context(2)
|
||||||
self.recv_from_tokenizer = context.socket(zmq.PULL)
|
self.recv_from_tokenizer = context.socket(zmq.PULL)
|
||||||
self.recv_from_tokenizer.bind(f"tcp://127.0.0.1:{port_args.router_port}")
|
self.recv_from_tokenizer.bind(f"tcp://127.0.0.1:{port_args.router_port}")
|
||||||
|
|
||||||
@@ -31,44 +107,52 @@ class ControllerSingle:
|
|||||||
f"tcp://127.0.0.1:{port_args.detokenizer_port}"
|
f"tcp://127.0.0.1:{port_args.detokenizer_port}"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Init status
|
# Init model server
|
||||||
self.model_client = model_client
|
tp_size_local = server_args.tp_size // server_args.nnodes
|
||||||
self.recv_reqs = []
|
gpu_ids = [i for _ in range(server_args.nnodes) for i in range(tp_size_local)]
|
||||||
|
|
||||||
# Init some configs
|
# Launch other tp ranks
|
||||||
self.request_dependency_delay = global_config.request_dependency_delay
|
if tp_size_local > 1:
|
||||||
|
tp_rank_range = range(1, tp_size_local)
|
||||||
|
self.tp_procs = launch_tp_servers(
|
||||||
|
gpu_ids, tp_rank_range, server_args,
|
||||||
|
port_args.model_port_args[0], model_overide_args)
|
||||||
|
|
||||||
async def loop_for_forward(self):
|
# Launch tp rank 0
|
||||||
|
self.tp_server = ModelTpServer(
|
||||||
|
gpu_ids[0],
|
||||||
|
0,
|
||||||
|
server_args,
|
||||||
|
port_args.model_port_args[0],
|
||||||
|
model_overide_args,
|
||||||
|
)
|
||||||
|
self.tp_cpu_group = self.tp_server.model_runner.tp_group.cpu_group
|
||||||
|
|
||||||
|
def loop_for_forward(self):
|
||||||
while True:
|
while True:
|
||||||
next_step_input = list(self.recv_reqs)
|
recv_reqs = self.recv_requests()
|
||||||
self.recv_reqs = []
|
|
||||||
out_pyobjs = await self.model_client.step(next_step_input)
|
if self.server_args.tp_size > 1:
|
||||||
|
broadcast_recv_input(recv_reqs, 0, self.tp_cpu_group)
|
||||||
|
|
||||||
|
out_pyobjs = self.tp_server.exposed_step(recv_reqs)
|
||||||
|
|
||||||
for obj in out_pyobjs:
|
for obj in out_pyobjs:
|
||||||
self.send_to_detokenizer.send_pyobj(obj)
|
self.send_to_detokenizer.send_pyobj(obj)
|
||||||
|
|
||||||
# async sleep for receiving the subsequent request and avoiding cache miss
|
def recv_requests(self):
|
||||||
slept = False
|
recv_reqs = []
|
||||||
if len(out_pyobjs) != 0:
|
|
||||||
has_finished = any(
|
|
||||||
[obj.finished_reason is not None for obj in out_pyobjs]
|
|
||||||
)
|
|
||||||
if has_finished:
|
|
||||||
if self.request_dependency_delay > 0:
|
|
||||||
slept = True
|
|
||||||
await asyncio.sleep(self.request_dependency_delay)
|
|
||||||
|
|
||||||
if not slept:
|
|
||||||
await asyncio.sleep(global_config.wait_for_new_request_delay)
|
|
||||||
|
|
||||||
async def loop_for_recv_requests(self):
|
|
||||||
while True:
|
while True:
|
||||||
recv_req = await self.recv_from_tokenizer.recv_pyobj()
|
try:
|
||||||
self.recv_reqs.append(recv_req)
|
recv_req = self.recv_from_tokenizer.recv_pyobj(zmq.NOBLOCK)
|
||||||
|
recv_reqs.append(recv_req)
|
||||||
|
except zmq.ZMQError:
|
||||||
|
break
|
||||||
|
return recv_reqs
|
||||||
|
|
||||||
|
|
||||||
def start_controller_process(
|
def start_controller_process(
|
||||||
server_args: ServerArgs, port_args: PortArgs, pipe_writer, model_overide_args
|
server_args: ServerArgs, port_args: PortArgs, pipe_writer, model_overide_args: dict
|
||||||
):
|
):
|
||||||
logging.basicConfig(
|
logging.basicConfig(
|
||||||
level=getattr(logging, server_args.log_level.upper()),
|
level=getattr(logging, server_args.log_level.upper()),
|
||||||
@@ -76,27 +160,18 @@ def start_controller_process(
|
|||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
tp_size_local = server_args.tp_size // server_args.nnodes
|
controller = ControllerSingle(server_args, port_args, model_overide_args)
|
||||||
model_client = ModelTpClient(
|
|
||||||
[i for _ in range(server_args.nnodes) for i in range(tp_size_local)],
|
|
||||||
server_args,
|
|
||||||
port_args.model_port_args[0],
|
|
||||||
model_overide_args,
|
|
||||||
)
|
|
||||||
controller = ControllerSingle(model_client, port_args)
|
|
||||||
except Exception:
|
except Exception:
|
||||||
pipe_writer.send(get_exception_traceback())
|
pipe_writer.send(get_exception_traceback())
|
||||||
raise
|
raise
|
||||||
|
|
||||||
pipe_writer.send("init ok")
|
pipe_writer.send("init ok")
|
||||||
|
|
||||||
loop = asyncio.new_event_loop()
|
|
||||||
loop.set_default_executor(ThreadPoolExecutor(max_workers=256))
|
|
||||||
asyncio.set_event_loop(loop)
|
|
||||||
loop.create_task(controller.loop_for_recv_requests())
|
|
||||||
try:
|
try:
|
||||||
loop.run_until_complete(controller.loop_for_forward())
|
controller.loop_for_forward()
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.error("Exception in ControllerSingle:\n" + get_exception_traceback())
|
logger.error("Exception in ControllerSingle:\n" + get_exception_traceback())
|
||||||
finally:
|
finally:
|
||||||
|
for t in controller.tp_procs:
|
||||||
|
os.kill(t.pid, 9)
|
||||||
kill_parent_process()
|
kill_parent_process()
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ import torch
|
|||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from vllm.config import DeviceConfig, LoadConfig
|
from vllm.config import DeviceConfig, LoadConfig
|
||||||
from vllm.config import ModelConfig as VllmModelConfig
|
from vllm.config import ModelConfig as VllmModelConfig
|
||||||
from vllm.distributed import init_distributed_environment, initialize_model_parallel
|
from vllm.distributed import init_distributed_environment, initialize_model_parallel, get_tp_group
|
||||||
from vllm.model_executor.model_loader import get_model
|
from vllm.model_executor.model_loader import get_model
|
||||||
from vllm.model_executor.models import ModelRegistry
|
from vllm.model_executor.models import ModelRegistry
|
||||||
|
|
||||||
@@ -75,6 +75,7 @@ class ModelRunner:
|
|||||||
distributed_init_method=nccl_init_method,
|
distributed_init_method=nccl_init_method,
|
||||||
)
|
)
|
||||||
initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
|
initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
|
||||||
|
self.tp_group = get_tp_group()
|
||||||
total_gpu_memory = get_available_gpu_memory(
|
total_gpu_memory = get_available_gpu_memory(
|
||||||
self.gpu_id, distributed=self.tp_size > 1
|
self.gpu_id, distributed=self.tp_size > 1
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -53,7 +53,7 @@ class ModelTpServer:
|
|||||||
tp_rank: int,
|
tp_rank: int,
|
||||||
server_args: ServerArgs,
|
server_args: ServerArgs,
|
||||||
model_port_args: ModelPortArgs,
|
model_port_args: ModelPortArgs,
|
||||||
model_overide_args,
|
model_overide_args: dict,
|
||||||
):
|
):
|
||||||
server_args, model_port_args = obtain(server_args), obtain(model_port_args)
|
server_args, model_port_args = obtain(server_args), obtain(model_port_args)
|
||||||
suppress_other_loggers()
|
suppress_other_loggers()
|
||||||
@@ -178,7 +178,7 @@ class ModelTpServer:
|
|||||||
self.new_token_ratio_recovery = global_config.new_token_ratio_recovery
|
self.new_token_ratio_recovery = global_config.new_token_ratio_recovery
|
||||||
|
|
||||||
def exposed_step(self, recv_reqs):
|
def exposed_step(self, recv_reqs):
|
||||||
if self.tp_size * self.dp_size != 1:
|
if not isinstance(recv_reqs, list):
|
||||||
recv_reqs = obtain(recv_reqs)
|
recv_reqs = obtain(recv_reqs)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -206,11 +206,11 @@ class ModelTpServer:
|
|||||||
|
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def forward_step(self):
|
def forward_step(self):
|
||||||
new_batch = self.get_new_fill_batch()
|
new_batch = self.get_new_prefill_batch()
|
||||||
|
|
||||||
if new_batch is not None:
|
if new_batch is not None:
|
||||||
# Run a new fill batch
|
# Run a new prefill batch
|
||||||
self.forward_fill_batch(new_batch)
|
self.forward_prefill_batch(new_batch)
|
||||||
self.cache_filled_batch(new_batch)
|
self.cache_filled_batch(new_batch)
|
||||||
|
|
||||||
if not new_batch.is_empty():
|
if not new_batch.is_empty():
|
||||||
@@ -219,7 +219,7 @@ class ModelTpServer:
|
|||||||
else:
|
else:
|
||||||
self.running_batch.merge(new_batch)
|
self.running_batch.merge(new_batch)
|
||||||
else:
|
else:
|
||||||
# Run decode batch
|
# Run a decode batch
|
||||||
if self.running_batch is not None:
|
if self.running_batch is not None:
|
||||||
# Run a few decode batches continuously for reducing overhead
|
# Run a few decode batches continuously for reducing overhead
|
||||||
for _ in range(global_config.num_continue_decode_steps):
|
for _ in range(global_config.num_continue_decode_steps):
|
||||||
@@ -312,7 +312,7 @@ class ModelTpServer:
|
|||||||
)
|
)
|
||||||
self.forward_queue.append(req)
|
self.forward_queue.append(req)
|
||||||
|
|
||||||
def get_new_fill_batch(self) -> Optional[Batch]:
|
def get_new_prefill_batch(self) -> Optional[Batch]:
|
||||||
running_bs = (
|
running_bs = (
|
||||||
len(self.running_batch.reqs) if self.running_batch is not None else 0
|
len(self.running_batch.reqs) if self.running_batch is not None else 0
|
||||||
)
|
)
|
||||||
@@ -436,7 +436,7 @@ class ModelTpServer:
|
|||||||
self.forward_queue = [x for x in self.forward_queue if x not in can_run_list]
|
self.forward_queue = [x for x in self.forward_queue if x not in can_run_list]
|
||||||
return new_batch
|
return new_batch
|
||||||
|
|
||||||
def forward_fill_batch(self, batch: Batch):
|
def forward_prefill_batch(self, batch: Batch):
|
||||||
# Build batch tensors
|
# Build batch tensors
|
||||||
batch.prepare_for_extend(
|
batch.prepare_for_extend(
|
||||||
self.model_config.vocab_size, self.int_token_logit_bias
|
self.model_config.vocab_size, self.int_token_logit_bias
|
||||||
@@ -746,8 +746,8 @@ class ModelTpClient:
|
|||||||
# Init model
|
# Init model
|
||||||
assert len(gpu_ids) == 1
|
assert len(gpu_ids) == 1
|
||||||
self.model_server = ModelTpService().exposed_ModelTpServer(
|
self.model_server = ModelTpService().exposed_ModelTpServer(
|
||||||
0,
|
|
||||||
gpu_ids[0],
|
gpu_ids[0],
|
||||||
|
0,
|
||||||
server_args,
|
server_args,
|
||||||
model_port_args,
|
model_port_args,
|
||||||
model_overide_args,
|
model_overide_args,
|
||||||
|
|||||||
@@ -33,9 +33,9 @@ from sglang.srt.managers.controller.manager_multi import (
|
|||||||
start_controller_process as start_controller_process_multi,
|
start_controller_process as start_controller_process_multi,
|
||||||
)
|
)
|
||||||
from sglang.srt.managers.controller.manager_single import (
|
from sglang.srt.managers.controller.manager_single import (
|
||||||
|
launch_tp_servers,
|
||||||
start_controller_process as start_controller_process_single,
|
start_controller_process as start_controller_process_single,
|
||||||
)
|
)
|
||||||
from sglang.srt.managers.controller.tp_worker import ModelTpService
|
|
||||||
from sglang.srt.managers.detokenizer_manager import start_detokenizer_process
|
from sglang.srt.managers.detokenizer_manager import start_detokenizer_process
|
||||||
from sglang.srt.managers.io_struct import GenerateReqInput
|
from sglang.srt.managers.io_struct import GenerateReqInput
|
||||||
from sglang.srt.managers.tokenizer_manager import TokenizerManager
|
from sglang.srt.managers.tokenizer_manager import TokenizerManager
|
||||||
@@ -53,7 +53,6 @@ from sglang.srt.utils import (
|
|||||||
enable_show_time_cost,
|
enable_show_time_cost,
|
||||||
receive_addrs,
|
receive_addrs,
|
||||||
send_addrs_to_rank_0,
|
send_addrs_to_rank_0,
|
||||||
start_rpyc_service_process,
|
|
||||||
)
|
)
|
||||||
from sglang.utils import get_exception_traceback
|
from sglang.utils import get_exception_traceback
|
||||||
|
|
||||||
@@ -192,21 +191,17 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
|
|||||||
model_port_args=model_port_args,
|
model_port_args=model_port_args,
|
||||||
)
|
)
|
||||||
|
|
||||||
# TODO multi-node dp is not supported
|
# Handle multi-node tp
|
||||||
assert not (server_args.dp_size > 1 and server_args.node_rank is not None)
|
|
||||||
if server_args.nnodes > 1:
|
if server_args.nnodes > 1:
|
||||||
|
assert server_args.dp_size == 1, "Multi-node dp is not supported."
|
||||||
|
|
||||||
if server_args.node_rank != 0:
|
if server_args.node_rank != 0:
|
||||||
send_addrs_to_rank_0(model_port_args[0], server_args)
|
tp_size_local = server_args.tp_size // server_args.nnodes
|
||||||
else:
|
gpu_ids = [i for _ in range(server_args.nnodes) for i in range(tp_size_local)]
|
||||||
receive_addrs(model_port_args[0], server_args)
|
tp_rank_range = list(range(server_args.node_rank * tp_size_local,
|
||||||
for i in range(tp_size_local):
|
(server_args.node_rank + 1) * tp_size_local))
|
||||||
start_rpyc_service_process(
|
procs = launch_tp_servers(gpu_ids, tp_rank_range, server_args,
|
||||||
ModelTpService, model_port_args[0].model_tp_ports[i]
|
port_args.model_port_args[0], model_overide_args)
|
||||||
)
|
|
||||||
if server_args.node_rank != 0:
|
|
||||||
logger.info(
|
|
||||||
f"[node_rank={server_args.node_rank}]: Listen for connections..."
|
|
||||||
)
|
|
||||||
while True:
|
while True:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|||||||
@@ -67,10 +67,12 @@ class ServerArgs:
|
|||||||
if self.tokenizer_path is None:
|
if self.tokenizer_path is None:
|
||||||
self.tokenizer_path = self.model_path
|
self.tokenizer_path = self.model_path
|
||||||
if self.mem_fraction_static is None:
|
if self.mem_fraction_static is None:
|
||||||
if self.tp_size >= 8:
|
if self.tp_size >= 16:
|
||||||
|
self.mem_fraction_static = 0.74
|
||||||
|
elif self.tp_size >= 8:
|
||||||
self.mem_fraction_static = 0.78
|
self.mem_fraction_static = 0.78
|
||||||
elif self.tp_size >= 4:
|
elif self.tp_size >= 4:
|
||||||
self.mem_fraction_static = 0.80
|
self.mem_fraction_static = 0.82
|
||||||
elif self.tp_size >= 2:
|
elif self.tp_size >= 2:
|
||||||
self.mem_fraction_static = 0.85
|
self.mem_fraction_static = 0.85
|
||||||
else:
|
else:
|
||||||
|
|||||||
Reference in New Issue
Block a user