From bea2bb9eeae6cf6f1bdfbb6aaaae2d91adea7bac Mon Sep 17 00:00:00 2001
From: Lianmin Zheng
Date: Tue, 20 Aug 2024 22:35:05 -0700
Subject: [PATCH] Improve multi-node stability (#1171)

---
 python/sglang/launch_server.py                |  9 ++-
 python/sglang/srt/hf_transformers_utils.py    | 16 ++--
 .../sglang/srt/managers/controller_multi.py   |  2 -
 .../sglang/srt/managers/controller_single.py  |  2 -
 python/sglang/srt/managers/schedule_batch.py  | 20 ++---
 python/sglang/srt/managers/tp_worker.py       | 16 ++--
 .../srt/model_executor/cuda_graph_runner.py   | 14 +++-
 .../sglang/srt/model_executor/model_runner.py |  1 +
 python/sglang/srt/server.py                   | 73 +++++++++----------
 python/sglang/srt/server_args.py              |  6 ++
 python/sglang/srt/utils.py                    | 11 ++-
 11 files changed, 94 insertions(+), 76 deletions(-)

diff --git a/python/sglang/launch_server.py b/python/sglang/launch_server.py
index 91dc0dc4e..1df64e848 100644
--- a/python/sglang/launch_server.py
+++ b/python/sglang/launch_server.py
@@ -1,9 +1,11 @@
 """Launch the inference server."""
 
 import argparse
+import os
 
 from sglang.srt.server import launch_server
 from sglang.srt.server_args import ServerArgs
+from sglang.srt.utils import kill_child_process
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
@@ -11,4 +13,9 @@ if __name__ == "__main__":
     args = parser.parse_args()
     server_args = ServerArgs.from_cli_args(args)
 
-    launch_server(server_args)
+    try:
+        launch_server(server_args)
+    except Exception as e:
+        raise e
+    finally:
+        kill_child_process(os.getpid(), including_parent=False)
diff --git a/python/sglang/srt/hf_transformers_utils.py b/python/sglang/srt/hf_transformers_utils.py
index b3576b47b..525d29543 100644
--- a/python/sglang/srt/hf_transformers_utils.py
+++ b/python/sglang/srt/hf_transformers_utils.py
@@ -233,6 +233,8 @@ class TiktokenTokenizer:
         }
         assert tok_dict["word_split"] == "V1"
 
+        default_allowed_special = None
+
         kwargs = {
             "name": name,
             "pat_str": tok_dict.get("pat_str", PAT_STR_B),
@@ -246,14 +248,18 @@ class TiktokenTokenizer:
                     for bytes_list in tok_dict["default_allowed_special"]
                 ]
             )
-        else:
-            default_allowed_special = None
         if "vocab_size" in tok_dict:
             kwargs["explicit_n_vocab"] = tok_dict["vocab_size"]
 
+        PAD = "<|pad|>"
+        EOS = "<|eos|>"
+        SEP = "<|separator|>"
+
+        DEFAULT_CONTROL_TOKENS = {"pad": PAD, "sep": SEP, "eos": EOS}
+
         tokenizer = tiktoken.Encoding(**kwargs)
         tokenizer._default_allowed_special = default_allowed_special or set()
-        tokenizer._default_allowed_special |= {"<|separator|>"}
+        tokenizer._control_tokens = DEFAULT_CONTROL_TOKENS
 
         def encode_patched(
             self,
@@ -270,14 +276,14 @@ class TiktokenTokenizer:
                 self,
                 text,
                 allowed_special=allowed_special,
-                disallowed_special=disallowed_special,
+                disallowed_special=(),
             )
 
         tokenizer.encode = functools.partial(encode_patched, tokenizer)
 
         # Convert to HF interface
         self.tokenizer = tokenizer
-        self.eos_token_id = tokenizer._special_tokens["<|eos|>"]
+        self.eos_token_id = tokenizer._special_tokens[EOS]
         self.vocab_size = tokenizer.n_vocab
         self.chat_template = Template(
             "{% for message in messages %}{% if message['role'] == 'user' %}{{ 'Human: ' + message['content'].strip() + '<|separator|>\n\n' }}{% elif message['role'] == 'system' %}{{ 'System: ' + message['content'].strip() + '<|separator|>\n\n' }}{% elif message['role'] == 'assistant' %}{{ 'Assistant: ' + message['content'] + '<|separator|>\n\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'Assistant:' }}{% endif %}"
diff --git a/python/sglang/srt/managers/controller_multi.py b/python/sglang/srt/managers/controller_multi.py
index dcd984e0f..58c4f4484 100644
--- a/python/sglang/srt/managers/controller_multi.py
+++ b/python/sglang/srt/managers/controller_multi.py
@@ -212,6 +212,4 @@ def start_controller_process(
     except Exception:
         logger.error("Exception in ControllerMulti:\n" + get_exception_traceback())
     finally:
-        for w in controller.workers:
-            os.kill(w.proc.pid, 9)
         kill_parent_process()
diff --git a/python/sglang/srt/managers/controller_single.py b/python/sglang/srt/managers/controller_single.py
index 415325b13..a3402c62f 100644
--- a/python/sglang/srt/managers/controller_single.py
+++ b/python/sglang/srt/managers/controller_single.py
@@ -167,6 +167,4 @@ def start_controller_process(
     except Exception:
         logger.error("Exception in ControllerSingle:\n" + get_exception_traceback())
     finally:
-        for t in controller.tp_procs:
-            os.kill(t.pid, 9)
         kill_parent_process()
diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py
index 14374e580..1437d0e6c 100644
--- a/python/sglang/srt/managers/schedule_batch.py
+++ b/python/sglang/srt/managers/schedule_batch.py
@@ -16,7 +16,6 @@ limitations under the License.
 """Meta data for requests and batches"""
 
 import logging
-import warnings
 from dataclasses import dataclass
 from typing import List, Optional, Union
 
@@ -270,7 +269,7 @@ class Req:
 
         if all_ids[prompt_tokens - 1] != self.origin_input_ids_unpadded[-1]:
             # TODO(lsyin): fix token fusion
-            logging.warning(
+            logger.warning(
                 "Token fusion between input and output, try to avoid this by removing the space at the end of the input."
             )
             return False
@@ -753,7 +752,7 @@ class ScheduleBatch:
         )
         self.logit_bias = torch.concat([self.logit_bias, other.logit_bias])
 
-    def sample(self, logits: torch.Tensor, is_multi_node_tp=False):
+    def sample(self, logits: torch.Tensor):
         # TODO(lsyin): move this into a part of layer and run with CUDA Graph
         # Post process logits
         logits = logits.contiguous()
@@ -791,7 +790,7 @@ class ScheduleBatch:
             )
 
             if not torch.all(success):
-                logging.warning("Sampling failed, fallback to top_k=1 strategy")
+                logger.warning(f"Sampling failed. Fallback to top_k=1 strategy. {logits=}")
                 probs = probs.masked_fill(torch.isnan(probs), 0.0)
                 argmax_ids = torch.argmax(probs, dim=-1)
                 batch_next_token_ids = torch.where(
@@ -808,16 +807,6 @@ class ScheduleBatch:
 
         self.penalizer_orchestrator.cumulate_output_tokens(batch_next_token_ids)
 
-        if is_multi_node_tp:
-            # If the tensor parallelism spans across multiple nodes, there is some indeterminism
-            # that can cause the TP workers to generate different tokens, so we need to
-            # sync here
-            torch.distributed.all_reduce(
-                batch_next_token_ids,
-                op=dist.ReduceOp.MIN,
-                group=get_tensor_model_parallel_group().device_group,
-            )
-
         return batch_next_token_ids
 
 
@@ -835,7 +824,8 @@ def top_k_top_p_sampling_from_probs_torch(
     probs_sort.div_(probs_sort.max(dim=-1, keepdim=True)[0])
     try:
         sampled_index = torch.multinomial(probs_sort, num_samples=1)
-    except RuntimeError:
+    except RuntimeError as e:
+        logger.warning(f"Sampling error: {e}")
         batch_next_token_ids = torch.zeros(
             (probs_sort.shape[0],), dtype=torch.int32, device=probs.device
         )
diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py
index 7bd2e3812..8772a4abb 100644
--- a/python/sglang/srt/managers/tp_worker.py
+++ b/python/sglang/srt/managers/tp_worker.py
@@ -133,6 +133,13 @@ class ModelTpServer:
             self.model_config.context_len - 1,
             self.max_total_num_tokens - 1,
         )
+
+        # Sync random seed
+        server_args.random_seed = broadcast_recv_input(
+            [server_args.random_seed],
+            self.tp_rank,
+            self.model_runner.tp_group.cpu_group,
+        )[0]
         set_random_seed(server_args.random_seed)
 
         # Print info
@@ -474,9 +481,7 @@ class ModelTpServer:
         # Forward and sample the next tokens
         if batch.extend_num_tokens != 0:
             output = self.model_runner.forward(batch, ForwardMode.EXTEND)
-            next_token_ids = batch.sample(
-                output.next_token_logits, self.model_runner.is_multi_node_tp
-            )
+            next_token_ids = batch.sample(output.next_token_logits)
 
             # Move logprobs to cpu
             if output.next_token_logprobs is not None:
@@ -636,9 +641,7 @@ class ModelTpServer:
 
         # Forward and sample the next tokens
         output = self.model_runner.forward(batch, ForwardMode.DECODE)
-        next_token_ids = batch.sample(
-            output.next_token_logits, self.model_runner.is_multi_node_tp
-        )
+        next_token_ids = batch.sample(output.next_token_logits)
 
         # Move logprobs to cpu
         if output.next_token_logprobs is not None:
@@ -879,6 +882,7 @@ def broadcast_recv_input(
 
         dist.broadcast(tensor_size, src=0, group=dist_group)
         dist.broadcast(tensor_data, src=0, group=dist_group)
+        return data
     else:
         tensor_size = torch.tensor([0], dtype=torch.long)
         dist.broadcast(tensor_size, src=0, group=dist_group)
diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py
index af39065cf..d045be56d 100644
--- a/python/sglang/srt/model_executor/cuda_graph_runner.py
+++ b/python/sglang/srt/model_executor/cuda_graph_runner.py
@@ -84,13 +84,20 @@ def set_torch_compile_config():
 
 
 class CudaGraphRunner:
-    def __init__(self, model_runner, max_batch_size_to_capture, use_torch_compile):
+    def __init__(
+        self,
+        model_runner,
+        max_batch_size_to_capture: int,
+        use_torch_compile: bool,
+        disable_padding: bool,
+    ):
         self.model_runner = model_runner
         self.graphs = {}
         self.input_buffers = {}
         self.output_buffers = {}
         self.flashinfer_handlers = {}
         self.graph_memory_pool = None
+        self.disable_padding = disable_padding
 
         # Common inputs
         self.max_bs = max_batch_size_to_capture
@@ -142,7 +149,10 @@ class CudaGraphRunner:
             set_torch_compile_config()
 
     def can_run(self, batch_size):
-        return batch_size <= self.max_bs
+        if self.disable_padding:
+            return batch_size in self.graphs
+        else:
+            return batch_size <= self.max_bs
 
     def capture(self, batch_size_list):
         self.batch_size_list = batch_size_list
diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py
index 4a3396cf2..a00a73945 100644
--- a/python/sglang/srt/model_executor/model_runner.py
+++ b/python/sglang/srt/model_executor/model_runner.py
@@ -465,6 +465,7 @@ class ModelRunner:
             self,
             max_batch_size_to_capture=max(batch_size_list),
             use_torch_compile=self.server_args.enable_torch_compile,
+            disable_padding=self.server_args.disable_cuda_graph_padding,
         )
         try:
             self.cuda_graph_runner.capture(batch_size_list)
diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py
index 0c5a3c706..fbe3374df 100644
--- a/python/sglang/srt/server.py
+++ b/python/sglang/srt/server.py
@@ -24,7 +24,6 @@ import json
 import logging
 import multiprocessing as mp
 import os
-import sys
 import threading
 import time
 from http import HTTPStatus
@@ -301,27 +300,29 @@ def launch_server(
     server_args.tokenizer_path = prepare_tokenizer(server_args.tokenizer_path)
 
     # Launch processes for multi-node tensor parallelism
-    if server_args.nnodes > 1:
-        if server_args.node_rank != 0:
-            tp_size_local = server_args.tp_size // server_args.nnodes
-            gpu_ids = [
-                i for _ in range(server_args.nnodes) for i in range(tp_size_local)
-            ]
-            tp_rank_range = list(
-                range(
-                    server_args.node_rank * tp_size_local,
-                    (server_args.node_rank + 1) * tp_size_local,
-                )
-            )
-            procs = launch_tp_servers(
-                gpu_ids,
-                tp_rank_range,
-                server_args,
-                ports[3],
-                model_overide_args,
-            )
-            while True:
-                pass
+    if server_args.nnodes > 1 and server_args.node_rank != 0:
+        tp_size_local = server_args.tp_size // server_args.nnodes
+        gpu_ids = [i for _ in range(server_args.nnodes) for i in range(tp_size_local)]
+        tp_rank_range = list(
+            range(
+                server_args.node_rank * tp_size_local,
+                (server_args.node_rank + 1) * tp_size_local,
+            )
+        )
+        procs = launch_tp_servers(
+            gpu_ids,
+            tp_rank_range,
+            server_args,
+            ports[3],
+            model_overide_args,
+        )
+
+        try:
+            for p in procs:
+                p.join()
+        finally:
+            kill_child_process(os.getpid(), including_parent=False)
+        return
 
     # Launch processes
     tokenizer_manager = TokenizerManager(server_args, port_args, model_overide_args)
@@ -356,15 +357,11 @@ def launch_server(
     if controller_init_state != "init ok" or detoken_init_state != "init ok":
         proc_controller.kill()
         proc_detoken.kill()
-        print(
-            f"Initialization failed. controller_init_state: {controller_init_state}",
-            flush=True,
+        raise RuntimeError(
+            "Initialization failed. "
+            f"controller_init_state: {controller_init_state}, "
+            f"detoken_init_state: {detoken_init_state}"
         )
-        print(
-            f"Initialization failed. detoken_init_state: {detoken_init_state}",
-            flush=True,
-        )
-        sys.exit(1)
     assert proc_controller.is_alive() and proc_detoken.is_alive()
 
     # Add api key authorization
@@ -373,12 +370,12 @@ def launch_server(
 
     # Send a warmup request
     t = threading.Thread(
-        target=_wait_and_warmup, args=(server_args, pipe_finish_writer)
+        target=_wait_and_warmup, args=(server_args, pipe_finish_writer, os.getpid())
     )
     t.start()
 
-    # Listen for requests
     try:
+        # Listen for requests
        uvicorn.run(
            app,
            host=server_args.host,
@@ -426,7 +423,7 @@ def _set_envs_and_config(server_args: ServerArgs):
     )
 
 
-def _wait_and_warmup(server_args, pipe_finish_writer):
+def _wait_and_warmup(server_args, pipe_finish_writer, pid):
     headers = {}
     url = server_args.url()
     if server_args.api_key:
@@ -449,8 +446,9 @@ def _wait_and_warmup(server_args, pipe_finish_writer):
     if not success:
         if pipe_finish_writer is not None:
             pipe_finish_writer.send(last_traceback)
-        print(f"Initialization failed. warmup error: {last_traceback}", flush=True)
-        sys.exit(1)
+        logger.error(f"Initialization failed. warmup error: {last_traceback}")
+        kill_child_process(pid, including_parent=False)
+        return
 
     # Send a warmup request
     request_name = "/generate" if model_info["is_generation"] else "/encode"
@@ -475,12 +473,13 @@ def _wait_and_warmup(server_args, pipe_finish_writer):
             timeout=600,
         )
         assert res.status_code == 200, f"{res}"
-    except Exception as e:
+    except Exception:
         last_traceback = get_exception_traceback()
         if pipe_finish_writer is not None:
             pipe_finish_writer.send(last_traceback)
-        print(f"Initialization failed. warmup error: {last_traceback}", flush=True)
-        sys.exit(1)
+        logger.error(f"Initialization failed. warmup error: {last_traceback}")
+        kill_child_process(pid, including_parent=False)
+        return
 
     logger.info("The server is fired up and ready to roll!")
     if pipe_finish_writer is not None:
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index c7120564c..4dd5bacca 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -79,6 +79,7 @@ class ServerArgs:
     disable_radix_cache: bool = False
     disable_regex_jump_forward: bool = False
     disable_cuda_graph: bool = False
+    disable_cuda_graph_padding: bool = False
     disable_disk_cache: bool = False
     enable_mixed_chunk: bool = False
     enable_torch_compile: bool = False
@@ -393,6 +394,11 @@ class ServerArgs:
             action="store_true",
             help="Disable cuda graph.",
         )
+        parser.add_argument(
+            "--disable-cuda-graph-padding",
+            action="store_true",
+            help="Disable cuda graph when padding is needed. Still uses cuda graph when padding is not needed.",
+        )
         parser.add_argument(
             "--disable-disk-cache",
             action="store_true",
diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py
index 9761c851a..a15ea1630 100644
--- a/python/sglang/srt/utils.py
+++ b/python/sglang/srt/utils.py
@@ -369,14 +369,11 @@ def kill_parent_process():
     """Kill the parent process and all children of the parent process."""
     current_process = psutil.Process()
     parent_process = current_process.parent()
-    children = parent_process.children(recursive=True)
-    for child in children:
-        if child.pid != current_process.pid:
-            os.kill(child.pid, 9)
-    os.kill(parent_process.pid, 9)
+    kill_child_process(parent_process.pid, skip_pid=current_process.pid)
 
 
-def kill_child_process(pid, including_parent=True):
+def kill_child_process(pid, including_parent=True, skip_pid=None):
+    """Kill the process and all its child processes."""
     try:
         parent = psutil.Process(pid)
     except psutil.NoSuchProcess:
@@ -384,6 +381,8 @@ def kill_child_process(pid, including_parent=True):
 
     children = parent.children(recursive=True)
     for child in children:
+        if child.pid == skip_pid:
+            continue
         try:
             child.kill()
         except psutil.NoSuchProcess:
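
Note on the seed sync in tp_worker.py: the patch replaces the per-step
all-reduce over sampled token ids with a one-time broadcast of rank 0's
random seed, so every tensor-parallel worker seeds its sampler identically
and multi-node ranks stop diverging. broadcast_recv_input does this by
pickling the object into a byte tensor, then broadcasting first its size and
then its payload over the CPU (gloo) group. A minimal self-contained sketch
of that pattern (broadcast_pyobj is an illustrative name, not the sglang
API; it assumes torch.distributed is already initialized with a
gloo-capable group):

    import pickle

    import torch
    import torch.distributed as dist

    def broadcast_pyobj(data, rank, dist_group):
        """Broadcast an arbitrary picklable object from rank 0 to all ranks."""
        if rank == 0:
            serialized = pickle.dumps(data)
            size = torch.tensor([len(serialized)], dtype=torch.long)
            payload = torch.frombuffer(bytearray(serialized), dtype=torch.uint8)
            dist.broadcast(size, src=0, group=dist_group)
            dist.broadcast(payload, src=0, group=dist_group)
            return data  # mirrors the `return data` added in this patch
        else:
            size = torch.tensor([0], dtype=torch.long)
            dist.broadcast(size, src=0, group=dist_group)
            # Allocate a matching buffer, receive the payload, and unpickle.
            payload = torch.empty(size.item(), dtype=torch.uint8)
            dist.broadcast(payload, src=0, group=dist_group)
            return pickle.loads(payload.numpy().tobytes())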
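
Note on --disable-cuda-graph-padding: CudaGraphRunner normally pads any
batch up to the nearest captured graph size, so can_run only checks
batch_size <= self.max_bs; with the flag set, only exactly-captured batch
sizes take the CUDA-graph path and everything else falls back to the regular
forward pass. A toy illustration of the dispatch rule (the capture list
[1, 2, 4, 8] is made up for the example):

    captured = [1, 2, 4, 8]
    graphs = {bs: f"graph_bs{bs}" for bs in captured}  # stand-ins for real graphs

    def can_run(batch_size: int, disable_padding: bool) -> bool:
        if disable_padding:
            # Only exact matches may replay a captured graph.
            return batch_size in graphs
        # Smaller batches are padded up to the next captured size.
        return batch_size <= max(captured)

    for bs in (3, 8, 9):
        print(bs, can_run(bs, False), can_run(bs, True))
    # 3 True False  -> bs=3 would be padded into the bs=4 graph unless disabled
    # 8 True True   -> exact match replays the captured graph either way
    # 9 False False -> above max_bs, both fall back to the normal forward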
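
Note on the multi-node launch path in server.py: worker nodes
(node_rank != 0) used to spin in `while True: pass`, which burned a CPU core
and left their tensor-parallel subprocesses orphaned if anything went wrong;
they now block on join() and always sweep their children on the way out. The
shape of that pattern, reduced to plain multiprocessing (the worker stub and
ranks are invented for the example):

    import multiprocessing as mp
    import time

    def tp_server_stub(rank):  # stand-in for a real TP server process
        time.sleep(1.0)
        print(f"rank {rank} exiting")

    if __name__ == "__main__":
        procs = [mp.Process(target=tp_server_stub, args=(r,)) for r in (2, 3)]
        for p in procs:
            p.start()
        try:
            for p in procs:
                p.join()  # idle wait instead of a busy loop
        finally:
            # On any exit path, reap whatever children are still alive.
            for p in procs:
                if p.is_alive():
                    p.terminate()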
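
Note on the utils.py cleanup: kill_parent_process is now a thin wrapper over
kill_child_process, and the new skip_pid parameter keeps the calling process
off the kill list so it can finish its own teardown. A hedged sketch of the
resulting process-tree sweep (kill_process_tree is an illustrative name for
the combined behavior, not the sglang function):

    import psutil

    def kill_process_tree(pid, including_parent=True, skip_pid=None):
        """Kill a process' descendants (and optionally the process itself)."""
        try:
            parent = psutil.Process(pid)
        except psutil.NoSuchProcess:
            return
        for child in parent.children(recursive=True):
            if child.pid == skip_pid:  # e.g. the caller cleaning up its siblings
                continue
            try:
                child.kill()
            except psutil.NoSuchProcess:
                pass
        if including_parent:
            try:
                parent.kill()
            except psutil.NoSuchProcess:
                pass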