Improve multi-node stability (#1171)
This commit is contained in:
@@ -1,9 +1,11 @@
|
||||
"""Launch the inference server."""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
|
||||
from sglang.srt.server import launch_server
|
||||
from sglang.srt.server_args import ServerArgs
|
||||
from sglang.srt.utils import kill_child_process
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
@@ -11,4 +13,9 @@ if __name__ == "__main__":
|
||||
args = parser.parse_args()
|
||||
server_args = ServerArgs.from_cli_args(args)
|
||||
|
||||
launch_server(server_args)
|
||||
try:
|
||||
launch_server(server_args)
|
||||
except Exception as e:
|
||||
raise e
|
||||
finally:
|
||||
kill_child_process(os.getpid(), including_parent=False)
|
||||
|
||||
@@ -233,6 +233,8 @@ class TiktokenTokenizer:
|
||||
}
|
||||
assert tok_dict["word_split"] == "V1"
|
||||
|
||||
default_allowed_special = None
|
||||
|
||||
kwargs = {
|
||||
"name": name,
|
||||
"pat_str": tok_dict.get("pat_str", PAT_STR_B),
|
||||
@@ -246,14 +248,18 @@ class TiktokenTokenizer:
|
||||
for bytes_list in tok_dict["default_allowed_special"]
|
||||
]
|
||||
)
|
||||
else:
|
||||
default_allowed_special = None
|
||||
if "vocab_size" in tok_dict:
|
||||
kwargs["explicit_n_vocab"] = tok_dict["vocab_size"]
|
||||
|
||||
PAD = "<|pad|>"
|
||||
EOS = "<|eos|>"
|
||||
SEP = "<|separator|>"
|
||||
|
||||
DEFAULT_CONTROL_TOKENS = {"pad": PAD, "sep": EOS, "eos": SEP}
|
||||
|
||||
tokenizer = tiktoken.Encoding(**kwargs)
|
||||
tokenizer._default_allowed_special = default_allowed_special or set()
|
||||
tokenizer._default_allowed_special |= {"<|separator|>"}
|
||||
tokenizer._control_tokens = DEFAULT_CONTROL_TOKENS
|
||||
|
||||
def encode_patched(
|
||||
self,
|
||||
@@ -270,14 +276,14 @@ class TiktokenTokenizer:
|
||||
self,
|
||||
text,
|
||||
allowed_special=allowed_special,
|
||||
disallowed_special=disallowed_special,
|
||||
disallowed_special=(),
|
||||
)
|
||||
|
||||
tokenizer.encode = functools.partial(encode_patched, tokenizer)
|
||||
|
||||
# Convert to HF interface
|
||||
self.tokenizer = tokenizer
|
||||
self.eos_token_id = tokenizer._special_tokens["<|eos|>"]
|
||||
self.eos_token_id = tokenizer._special_tokens[EOS]
|
||||
self.vocab_size = tokenizer.n_vocab
|
||||
self.chat_template = Template(
|
||||
"{% for message in messages %}{% if message['role'] == 'user' %}{{ 'Human: ' + message['content'].strip() + '<|separator|>\n\n' }}{% elif message['role'] == 'system' %}{{ 'System: ' + message['content'].strip() + '<|separator|>\n\n' }}{% elif message['role'] == 'assistant' %}{{ 'Assistant: ' + message['content'] + '<|separator|>\n\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'Assistant:' }}{% endif %}"
|
||||
|
||||
@@ -212,6 +212,4 @@ def start_controller_process(
|
||||
except Exception:
|
||||
logger.error("Exception in ControllerMulti:\n" + get_exception_traceback())
|
||||
finally:
|
||||
for w in controller.workers:
|
||||
os.kill(w.proc.pid, 9)
|
||||
kill_parent_process()
|
||||
|
||||
@@ -167,6 +167,4 @@ def start_controller_process(
|
||||
except Exception:
|
||||
logger.error("Exception in ControllerSingle:\n" + get_exception_traceback())
|
||||
finally:
|
||||
for t in controller.tp_procs:
|
||||
os.kill(t.pid, 9)
|
||||
kill_parent_process()
|
||||
|
||||
@@ -16,7 +16,6 @@ limitations under the License.
|
||||
"""Meta data for requests and batches"""
|
||||
|
||||
import logging
|
||||
import warnings
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional, Union
|
||||
|
||||
@@ -270,7 +269,7 @@ class Req:
|
||||
|
||||
if all_ids[prompt_tokens - 1] != self.origin_input_ids_unpadded[-1]:
|
||||
# TODO(lsyin): fix token fusion
|
||||
logging.warning(
|
||||
logger.warning(
|
||||
"Token fusion between input and output, try to avoid this by removing the space at the end of the input."
|
||||
)
|
||||
return False
|
||||
@@ -753,7 +752,7 @@ class ScheduleBatch:
|
||||
)
|
||||
self.logit_bias = torch.concat([self.logit_bias, other.logit_bias])
|
||||
|
||||
def sample(self, logits: torch.Tensor, is_multi_node_tp=False):
|
||||
def sample(self, logits: torch.Tensor):
|
||||
# TODO(lsyin): move this into a part of layer and run with CUDA Graph
|
||||
# Post process logits
|
||||
logits = logits.contiguous()
|
||||
@@ -791,7 +790,7 @@ class ScheduleBatch:
|
||||
)
|
||||
|
||||
if not torch.all(success):
|
||||
logging.warning("Sampling failed, fallback to top_k=1 strategy")
|
||||
logger.warning(f"Sampling failed. Fallback to top_k=1 strategy. {logits=}")
|
||||
probs = probs.masked_fill(torch.isnan(probs), 0.0)
|
||||
argmax_ids = torch.argmax(probs, dim=-1)
|
||||
batch_next_token_ids = torch.where(
|
||||
@@ -808,16 +807,6 @@ class ScheduleBatch:
|
||||
|
||||
self.penalizer_orchestrator.cumulate_output_tokens(batch_next_token_ids)
|
||||
|
||||
if is_multi_node_tp:
|
||||
# If the tensor parallelism spans across multiple nodes, there is some nondeterminism
|
||||
# that can cause the TP workers to generate different tokens, so we need to
|
||||
# sync here
|
||||
torch.distributed.all_reduce(
|
||||
batch_next_token_ids,
|
||||
op=dist.ReduceOp.MIN,
|
||||
group=get_tensor_model_parallel_group().device_group,
|
||||
)
|
||||
|
||||
return batch_next_token_ids
|
||||
|
||||
|
||||
@@ -835,7 +824,8 @@ def top_k_top_p_sampling_from_probs_torch(
|
||||
probs_sort.div_(probs_sort.max(dim=-1, keepdim=True)[0])
|
||||
try:
|
||||
sampled_index = torch.multinomial(probs_sort, num_samples=1)
|
||||
except RuntimeError:
|
||||
except RuntimeError as e:
|
||||
logger.warning(f"Sampling error: {e}")
|
||||
batch_next_token_ids = torch.zeros(
|
||||
(probs_sort.shape[0],), dtype=torch.int32, device=probs.device
|
||||
)
|
||||
|
||||
@@ -133,6 +133,13 @@ class ModelTpServer:
|
||||
self.model_config.context_len - 1,
|
||||
self.max_total_num_tokens - 1,
|
||||
)
|
||||
|
||||
# Sync random seed
|
||||
server_args.random_seed = broadcast_recv_input(
|
||||
[server_args.random_seed],
|
||||
self.tp_rank,
|
||||
self.model_runner.tp_group.cpu_group,
|
||||
)[0]
|
||||
set_random_seed(server_args.random_seed)
|
||||
|
||||
# Print info
|
||||
@@ -474,9 +481,7 @@ class ModelTpServer:
|
||||
# Forward and sample the next tokens
|
||||
if batch.extend_num_tokens != 0:
|
||||
output = self.model_runner.forward(batch, ForwardMode.EXTEND)
|
||||
next_token_ids = batch.sample(
|
||||
output.next_token_logits, self.model_runner.is_multi_node_tp
|
||||
)
|
||||
next_token_ids = batch.sample(output.next_token_logits)
|
||||
|
||||
# Move logprobs to cpu
|
||||
if output.next_token_logprobs is not None:
|
||||
@@ -636,9 +641,7 @@ class ModelTpServer:
|
||||
|
||||
# Forward and sample the next tokens
|
||||
output = self.model_runner.forward(batch, ForwardMode.DECODE)
|
||||
next_token_ids = batch.sample(
|
||||
output.next_token_logits, self.model_runner.is_multi_node_tp
|
||||
)
|
||||
next_token_ids = batch.sample(output.next_token_logits)
|
||||
|
||||
# Move logprobs to cpu
|
||||
if output.next_token_logprobs is not None:
|
||||
@@ -879,6 +882,7 @@ def broadcast_recv_input(
|
||||
|
||||
dist.broadcast(tensor_size, src=0, group=dist_group)
|
||||
dist.broadcast(tensor_data, src=0, group=dist_group)
|
||||
return data
|
||||
else:
|
||||
tensor_size = torch.tensor([0], dtype=torch.long)
|
||||
dist.broadcast(tensor_size, src=0, group=dist_group)
|
||||
|
||||
@@ -84,13 +84,20 @@ def set_torch_compile_config():
|
||||
|
||||
|
||||
class CudaGraphRunner:
|
||||
def __init__(self, model_runner, max_batch_size_to_capture, use_torch_compile):
|
||||
def __init__(
|
||||
self,
|
||||
model_runner,
|
||||
max_batch_size_to_capture: int,
|
||||
use_torch_compile: bool,
|
||||
disable_padding: bool,
|
||||
):
|
||||
self.model_runner = model_runner
|
||||
self.graphs = {}
|
||||
self.input_buffers = {}
|
||||
self.output_buffers = {}
|
||||
self.flashinfer_handlers = {}
|
||||
self.graph_memory_pool = None
|
||||
self.disable_padding = disable_padding
|
||||
|
||||
# Common inputs
|
||||
self.max_bs = max_batch_size_to_capture
|
||||
@@ -142,7 +149,10 @@ class CudaGraphRunner:
|
||||
set_torch_compile_config()
|
||||
|
||||
def can_run(self, batch_size):
|
||||
return batch_size <= self.max_bs
|
||||
if self.disable_padding:
|
||||
return batch_size in self.graphs
|
||||
else:
|
||||
return batch_size <= self.max_bs
|
||||
|
||||
def capture(self, batch_size_list):
|
||||
self.batch_size_list = batch_size_list
|
||||
|
||||
@@ -465,6 +465,7 @@ class ModelRunner:
|
||||
self,
|
||||
max_batch_size_to_capture=max(batch_size_list),
|
||||
use_torch_compile=self.server_args.enable_torch_compile,
|
||||
disable_padding=self.server_args.disable_cuda_graph_padding,
|
||||
)
|
||||
try:
|
||||
self.cuda_graph_runner.capture(batch_size_list)
|
||||
|
||||
@@ -24,7 +24,6 @@ import json
|
||||
import logging
|
||||
import multiprocessing as mp
|
||||
import os
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
from http import HTTPStatus
|
||||
@@ -301,27 +300,29 @@ def launch_server(
|
||||
server_args.tokenizer_path = prepare_tokenizer(server_args.tokenizer_path)
|
||||
|
||||
# Launch processes for multi-node tensor parallelism
|
||||
if server_args.nnodes > 1:
|
||||
if server_args.node_rank != 0:
|
||||
tp_size_local = server_args.tp_size // server_args.nnodes
|
||||
gpu_ids = [
|
||||
i for _ in range(server_args.nnodes) for i in range(tp_size_local)
|
||||
]
|
||||
tp_rank_range = list(
|
||||
range(
|
||||
server_args.node_rank * tp_size_local,
|
||||
(server_args.node_rank + 1) * tp_size_local,
|
||||
)
|
||||
if server_args.nnodes > 1 and server_args.node_rank != 0:
|
||||
tp_size_local = server_args.tp_size // server_args.nnodes
|
||||
gpu_ids = [i for _ in range(server_args.nnodes) for i in range(tp_size_local)]
|
||||
tp_rank_range = list(
|
||||
range(
|
||||
server_args.node_rank * tp_size_local,
|
||||
(server_args.node_rank + 1) * tp_size_local,
|
||||
)
|
||||
procs = launch_tp_servers(
|
||||
gpu_ids,
|
||||
tp_rank_range,
|
||||
server_args,
|
||||
ports[3],
|
||||
model_overide_args,
|
||||
)
|
||||
while True:
|
||||
pass
|
||||
)
|
||||
procs = launch_tp_servers(
|
||||
gpu_ids,
|
||||
tp_rank_range,
|
||||
server_args,
|
||||
ports[3],
|
||||
model_overide_args,
|
||||
)
|
||||
|
||||
try:
|
||||
for p in procs:
|
||||
p.join()
|
||||
finally:
|
||||
kill_child_process(os.getpid(), including_parent=False)
|
||||
return
|
||||
|
||||
# Launch processes
|
||||
tokenizer_manager = TokenizerManager(server_args, port_args, model_overide_args)
|
||||
@@ -356,15 +357,11 @@ def launch_server(
|
||||
if controller_init_state != "init ok" or detoken_init_state != "init ok":
|
||||
proc_controller.kill()
|
||||
proc_detoken.kill()
|
||||
print(
|
||||
f"Initialization failed. controller_init_state: {controller_init_state}",
|
||||
flush=True,
|
||||
raise RuntimeError(
|
||||
"Initialization failed. "
|
||||
f"controller_init_state: {controller_init_state}, "
|
||||
f"detoken_init_state: {detoken_init_state}"
|
||||
)
|
||||
print(
|
||||
f"Initialization failed. detoken_init_state: {detoken_init_state}",
|
||||
flush=True,
|
||||
)
|
||||
sys.exit(1)
|
||||
assert proc_controller.is_alive() and proc_detoken.is_alive()
|
||||
|
||||
# Add api key authorization
|
||||
@@ -373,12 +370,12 @@ def launch_server(
|
||||
|
||||
# Send a warmup request
|
||||
t = threading.Thread(
|
||||
target=_wait_and_warmup, args=(server_args, pipe_finish_writer)
|
||||
target=_wait_and_warmup, args=(server_args, pipe_finish_writer, os.getpid())
|
||||
)
|
||||
t.start()
|
||||
|
||||
# Listen for requests
|
||||
try:
|
||||
# Listen for requests
|
||||
uvicorn.run(
|
||||
app,
|
||||
host=server_args.host,
|
||||
@@ -426,7 +423,7 @@ def _set_envs_and_config(server_args: ServerArgs):
|
||||
)
|
||||
|
||||
|
||||
def _wait_and_warmup(server_args, pipe_finish_writer):
|
||||
def _wait_and_warmup(server_args, pipe_finish_writer, pid):
|
||||
headers = {}
|
||||
url = server_args.url()
|
||||
if server_args.api_key:
|
||||
@@ -449,8 +446,9 @@ def _wait_and_warmup(server_args, pipe_finish_writer):
|
||||
if not success:
|
||||
if pipe_finish_writer is not None:
|
||||
pipe_finish_writer.send(last_traceback)
|
||||
print(f"Initialization failed. warmup error: {last_traceback}", flush=True)
|
||||
sys.exit(1)
|
||||
logger.error(f"Initialization failed. warmup error: {last_traceback}")
|
||||
kill_child_process(pid, including_parent=False)
|
||||
return
|
||||
|
||||
# Send a warmup request
|
||||
request_name = "/generate" if model_info["is_generation"] else "/encode"
|
||||
@@ -475,12 +473,13 @@ def _wait_and_warmup(server_args, pipe_finish_writer):
|
||||
timeout=600,
|
||||
)
|
||||
assert res.status_code == 200, f"{res}"
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
last_traceback = get_exception_traceback()
|
||||
if pipe_finish_writer is not None:
|
||||
pipe_finish_writer.send(last_traceback)
|
||||
print(f"Initialization failed. warmup error: {last_traceback}", flush=True)
|
||||
sys.exit(1)
|
||||
logger.error(f"Initialization failed. warmup error: {last_traceback}")
|
||||
kill_child_process(pid, including_parent=False)
|
||||
return
|
||||
|
||||
logger.info("The server is fired up and ready to roll!")
|
||||
if pipe_finish_writer is not None:
|
||||
|
||||
@@ -79,6 +79,7 @@ class ServerArgs:
|
||||
disable_radix_cache: bool = False
|
||||
disable_regex_jump_forward: bool = False
|
||||
disable_cuda_graph: bool = False
|
||||
disable_cuda_graph_padding: bool = False
|
||||
disable_disk_cache: bool = False
|
||||
enable_mixed_chunk: bool = False
|
||||
enable_torch_compile: bool = False
|
||||
@@ -393,6 +394,11 @@ class ServerArgs:
|
||||
action="store_true",
|
||||
help="Disable cuda graph.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--disable-cuda-graph-padding",
|
||||
action="store_true",
|
||||
help="Disable cuda graph when padding is needed. Still uses cuda graph when padding is not needed.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--disable-disk-cache",
|
||||
action="store_true",
|
||||
|
||||
@@ -369,14 +369,11 @@ def kill_parent_process():
|
||||
"""Kill the parent process and all children of the parent process."""
|
||||
current_process = psutil.Process()
|
||||
parent_process = current_process.parent()
|
||||
children = parent_process.children(recursive=True)
|
||||
for child in children:
|
||||
if child.pid != current_process.pid:
|
||||
os.kill(child.pid, 9)
|
||||
os.kill(parent_process.pid, 9)
|
||||
kill_child_process(parent_process.pid, skip_pid=current_process.pid)
|
||||
|
||||
|
||||
def kill_child_process(pid, including_parent=True):
|
||||
def kill_child_process(pid, including_parent=True, skip_pid=None):
|
||||
"""Kill the process and all its children process."""
|
||||
try:
|
||||
parent = psutil.Process(pid)
|
||||
except psutil.NoSuchProcess:
|
||||
@@ -384,6 +381,8 @@ def kill_child_process(pid, including_parent=True):
|
||||
|
||||
children = parent.children(recursive=True)
|
||||
for child in children:
|
||||
if child.pid == skip_pid:
|
||||
continue
|
||||
try:
|
||||
child.kill()
|
||||
except psutil.NoSuchProcess:
|
||||
|
||||
Reference in New Issue
Block a user