Improve multi-node stability (#1171)
@@ -1,9 +1,11 @@
 """Launch the inference server."""
 
 import argparse
+import os
 
 from sglang.srt.server import launch_server
 from sglang.srt.server_args import ServerArgs
+from sglang.srt.utils import kill_child_process
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
@@ -11,4 +13,9 @@ if __name__ == "__main__":
     args = parser.parse_args()
     server_args = ServerArgs.from_cli_args(args)
 
-    launch_server(server_args)
+    try:
+        launch_server(server_args)
+    except Exception as e:
+        raise e
+    finally:
+        kill_child_process(os.getpid(), including_parent=False)

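The launcher now reaps its children whenever it exits. A rough psutil equivalent of the kill_child_process(os.getpid(), including_parent=False) call, assuming the helper behaves as defined in utils.py later in this commit:

    import os
    import psutil

    # Reap every process the launcher spawned, but leave the launcher itself running.
    for child in psutil.Process(os.getpid()).children(recursive=True):
        child.kill()
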
@@ -233,6 +233,8 @@ class TiktokenTokenizer:
         }
         assert tok_dict["word_split"] == "V1"
 
+        default_allowed_special = None
+
         kwargs = {
             "name": name,
             "pat_str": tok_dict.get("pat_str", PAT_STR_B),
@@ -246,14 +248,18 @@ class TiktokenTokenizer:
                     for bytes_list in tok_dict["default_allowed_special"]
                 ]
             )
-        else:
-            default_allowed_special = None
         if "vocab_size" in tok_dict:
             kwargs["explicit_n_vocab"] = tok_dict["vocab_size"]
 
+        PAD = "<|pad|>"
+        EOS = "<|eos|>"
+        SEP = "<|separator|>"
+
+        DEFAULT_CONTROL_TOKENS = {"pad": PAD, "sep": EOS, "eos": SEP}
+
         tokenizer = tiktoken.Encoding(**kwargs)
         tokenizer._default_allowed_special = default_allowed_special or set()
-        tokenizer._default_allowed_special |= {"<|separator|>"}
+        tokenizer._control_tokens = DEFAULT_CONTROL_TOKENS
 
         def encode_patched(
             self,
@@ -270,14 +276,14 @@ class TiktokenTokenizer:
                 self,
                 text,
                 allowed_special=allowed_special,
-                disallowed_special=disallowed_special,
+                disallowed_special=(),
             )
 
         tokenizer.encode = functools.partial(encode_patched, tokenizer)
 
         # Convert to HF interface
         self.tokenizer = tokenizer
-        self.eos_token_id = tokenizer._special_tokens["<|eos|>"]
+        self.eos_token_id = tokenizer._special_tokens[EOS]
         self.vocab_size = tokenizer.n_vocab
         self.chat_template = Template(
             "{% for message in messages %}{% if message['role'] == 'user' %}{{ 'Human: ' + message['content'].strip() + '<|separator|>\n\n' }}{% elif message['role'] == 'system' %}{{ 'System: ' + message['content'].strip() + '<|separator|>\n\n' }}{% elif message['role'] == 'assistant' %}{{ 'Assistant: ' + message['content'] + '<|separator|>\n\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'Assistant:' }}{% endif %}"

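The patched encode() above hard-codes disallowed_special=(). For background, this is stock tiktoken behavior, shown here with an example encoding name: by default, special-token markup inside the text raises, while an empty disallowed set encodes it as ordinary text.

    import tiktoken

    enc = tiktoken.get_encoding("cl100k_base")
    enc.encode("hi <|endoftext|>", disallowed_special=())  # encoded as plain text, no error
    # enc.encode("hi <|endoftext|>")  # default settings would raise ValueError
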
@@ -212,6 +212,4 @@ def start_controller_process(
     except Exception:
         logger.error("Exception in ControllerMulti:\n" + get_exception_traceback())
     finally:
-        for w in controller.workers:
-            os.kill(w.proc.pid, 9)
         kill_parent_process()

@@ -167,6 +167,4 @@ def start_controller_process(
     except Exception:
         logger.error("Exception in ControllerSingle:\n" + get_exception_traceback())
     finally:
-        for t in controller.tp_procs:
-            os.kill(t.pid, 9)
         kill_parent_process()

@@ -16,7 +16,6 @@ limitations under the License.
 """Meta data for requests and batches"""
 
 import logging
-import warnings
 from dataclasses import dataclass
 from typing import List, Optional, Union
 
@@ -270,7 +269,7 @@ class Req:
 
         if all_ids[prompt_tokens - 1] != self.origin_input_ids_unpadded[-1]:
             # TODO(lsyin): fix token fusion
-            logging.warning(
+            logger.warning(
                 "Token fusion between input and output, try to avoid this by removing the space at the end of the input."
             )
             return False
@@ -753,7 +752,7 @@ class ScheduleBatch:
         )
         self.logit_bias = torch.concat([self.logit_bias, other.logit_bias])
 
-    def sample(self, logits: torch.Tensor, is_multi_node_tp=False):
+    def sample(self, logits: torch.Tensor):
         # TODO(lsyin): move this into a part of layer and run with CUDA Graph
         # Post process logits
         logits = logits.contiguous()
@@ -791,7 +790,7 @@ class ScheduleBatch:
             )
 
             if not torch.all(success):
-                logging.warning("Sampling failed, fallback to top_k=1 strategy")
+                logger.warning(f"Sampling failed. Fallback to top_k=1 strategy. {logits=}")
                 probs = probs.masked_fill(torch.isnan(probs), 0.0)
                 argmax_ids = torch.argmax(probs, dim=-1)
                 batch_next_token_ids = torch.where(
@@ -808,16 +807,6 @@ class ScheduleBatch:
 
         self.penalizer_orchestrator.cumulate_output_tokens(batch_next_token_ids)
 
-        if is_multi_node_tp:
-            # If the tensor parallelism spans across multiple nodes, there is some indeterminism
-            # that can cause the TP workers to generate different tokens, so we need to
-            # sync here
-            torch.distributed.all_reduce(
-                batch_next_token_ids,
-                op=dist.ReduceOp.MIN,
-                group=get_tensor_model_parallel_group().device_group,
-            )
-
         return batch_next_token_ids
 
 
@@ -835,7 +824,8 @@ def top_k_top_p_sampling_from_probs_torch(
     probs_sort.div_(probs_sort.max(dim=-1, keepdim=True)[0])
     try:
         sampled_index = torch.multinomial(probs_sort, num_samples=1)
-    except RuntimeError:
+    except RuntimeError as e:
+        logger.warning(f"Sampling error: {e}")
         batch_next_token_ids = torch.zeros(
             (probs_sort.shape[0],), dtype=torch.int32, device=probs.device
         )

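Dropping the per-token all_reduce relies on every TP worker drawing the same sample, which holds once the workers share one random seed (see the seed broadcast added to the TP worker below) and see identical logits. A minimal sketch of that assumption:

    import torch

    probs = torch.tensor([0.1, 0.2, 0.7])
    torch.manual_seed(42)
    a = torch.multinomial(probs, num_samples=1)
    torch.manual_seed(42)
    b = torch.multinomial(probs, num_samples=1)
    assert torch.equal(a, b)  # same seed + same probs => same sampled token on every rank
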
@@ -133,6 +133,13 @@ class ModelTpServer:
             self.model_config.context_len - 1,
             self.max_total_num_tokens - 1,
         )
+
+        # Sync random seed
+        server_args.random_seed = broadcast_recv_input(
+            [server_args.random_seed],
+            self.tp_rank,
+            self.model_runner.tp_group.cpu_group,
+        )[0]
         set_random_seed(server_args.random_seed)
 
         # Print info
@@ -474,9 +481,7 @@ class ModelTpServer:
             # Forward and sample the next tokens
             if batch.extend_num_tokens != 0:
                 output = self.model_runner.forward(batch, ForwardMode.EXTEND)
-                next_token_ids = batch.sample(
-                    output.next_token_logits, self.model_runner.is_multi_node_tp
-                )
+                next_token_ids = batch.sample(output.next_token_logits)
 
                 # Move logprobs to cpu
                 if output.next_token_logprobs is not None:
@@ -636,9 +641,7 @@ class ModelTpServer:
 
             # Forward and sample the next tokens
             output = self.model_runner.forward(batch, ForwardMode.DECODE)
-            next_token_ids = batch.sample(
-                output.next_token_logits, self.model_runner.is_multi_node_tp
-            )
+            next_token_ids = batch.sample(output.next_token_logits)
 
             # Move logprobs to cpu
             if output.next_token_logprobs is not None:
@@ -879,6 +882,7 @@ def broadcast_recv_input(
 
         dist.broadcast(tensor_size, src=0, group=dist_group)
         dist.broadcast(tensor_data, src=0, group=dist_group)
+        return data
     else:
         tensor_size = torch.tensor([0], dtype=torch.long)
         dist.broadcast(tensor_size, src=0, group=dist_group)

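A stock-PyTorch stand-in for the seed sync above, assuming the default process group is already initialized; broadcast_object_list plays the role of the project's broadcast_recv_input helper for a small Python object:

    import torch.distributed as dist

    seed = [42 if dist.get_rank() == 0 else None]
    dist.broadcast_object_list(seed, src=0)
    local_seed = seed[0]  # identical on every TP worker, so sampling stays in lockstep
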
@@ -84,13 +84,20 @@ def set_torch_compile_config():
 
 
 class CudaGraphRunner:
-    def __init__(self, model_runner, max_batch_size_to_capture, use_torch_compile):
+    def __init__(
+        self,
+        model_runner,
+        max_batch_size_to_capture: int,
+        use_torch_compile: bool,
+        disable_padding: bool,
+    ):
         self.model_runner = model_runner
         self.graphs = {}
         self.input_buffers = {}
         self.output_buffers = {}
         self.flashinfer_handlers = {}
         self.graph_memory_pool = None
+        self.disable_padding = disable_padding
 
         # Common inputs
         self.max_bs = max_batch_size_to_capture
@@ -142,7 +149,10 @@ class CudaGraphRunner:
             set_torch_compile_config()
 
     def can_run(self, batch_size):
-        return batch_size <= self.max_bs
+        if self.disable_padding:
+            return batch_size in self.graphs
+        else:
+            return batch_size <= self.max_bs
 
     def capture(self, batch_size_list):
         self.batch_size_list = batch_size_list

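A toy illustration of the can_run() change (the numbers are made up): with padding disabled, only exactly-captured batch sizes are eligible for CUDA-graph replay, and everything else falls back to the regular forward pass instead of being padded up to the next captured size.

    captured = {1, 2, 4, 8}  # stand-in for the keys of self.graphs
    max_bs = 8

    def can_run(batch_size, disable_padding):
        return batch_size in captured if disable_padding else batch_size <= max_bs

    assert can_run(3, disable_padding=False)      # would be padded up to 4
    assert not can_run(3, disable_padding=True)   # takes the non-graph path instead
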
@@ -465,6 +465,7 @@ class ModelRunner:
             self,
             max_batch_size_to_capture=max(batch_size_list),
             use_torch_compile=self.server_args.enable_torch_compile,
+            disable_padding=self.server_args.disable_cuda_graph_padding,
         )
         try:
             self.cuda_graph_runner.capture(batch_size_list)

@@ -24,7 +24,6 @@ import json
 import logging
 import multiprocessing as mp
 import os
-import sys
 import threading
 import time
 from http import HTTPStatus
@@ -301,27 +300,29 @@ def launch_server(
     server_args.tokenizer_path = prepare_tokenizer(server_args.tokenizer_path)
 
     # Launch processes for multi-node tensor parallelism
-    if server_args.nnodes > 1:
-        if server_args.node_rank != 0:
-            tp_size_local = server_args.tp_size // server_args.nnodes
-            gpu_ids = [
-                i for _ in range(server_args.nnodes) for i in range(tp_size_local)
-            ]
-            tp_rank_range = list(
-                range(
-                    server_args.node_rank * tp_size_local,
-                    (server_args.node_rank + 1) * tp_size_local,
-                )
-            )
-            procs = launch_tp_servers(
-                gpu_ids,
-                tp_rank_range,
-                server_args,
-                ports[3],
-                model_overide_args,
-            )
-            while True:
-                pass
+    if server_args.nnodes > 1 and server_args.node_rank != 0:
+        tp_size_local = server_args.tp_size // server_args.nnodes
+        gpu_ids = [i for _ in range(server_args.nnodes) for i in range(tp_size_local)]
+        tp_rank_range = list(
+            range(
+                server_args.node_rank * tp_size_local,
+                (server_args.node_rank + 1) * tp_size_local,
+            )
+        )
+        procs = launch_tp_servers(
+            gpu_ids,
+            tp_rank_range,
+            server_args,
+            ports[3],
+            model_overide_args,
+        )
+
+        try:
+            for p in procs:
+                p.join()
+        finally:
+            kill_child_process(os.getpid(), including_parent=False)
+        return
 
     # Launch processes
     tokenizer_manager = TokenizerManager(server_args, port_args, model_overide_args)
@@ -356,15 +357,11 @@ def launch_server(
     if controller_init_state != "init ok" or detoken_init_state != "init ok":
         proc_controller.kill()
         proc_detoken.kill()
-        print(
-            f"Initialization failed. controller_init_state: {controller_init_state}",
-            flush=True,
-        )
-        print(
-            f"Initialization failed. detoken_init_state: {detoken_init_state}",
-            flush=True,
-        )
-        sys.exit(1)
+        raise RuntimeError(
+            "Initialization failed. "
+            f"controller_init_state: {controller_init_state}, "
+            f"detoken_init_state: {detoken_init_state}"
+        )
     assert proc_controller.is_alive() and proc_detoken.is_alive()
 
     # Add api key authorization
@@ -373,12 +370,12 @@ def launch_server(
 
     # Send a warmup request
     t = threading.Thread(
-        target=_wait_and_warmup, args=(server_args, pipe_finish_writer)
+        target=_wait_and_warmup, args=(server_args, pipe_finish_writer, os.getpid())
     )
     t.start()
 
-    # Listen for requests
     try:
+        # Listen for requests
         uvicorn.run(
             app,
             host=server_args.host,
@@ -426,7 +423,7 @@ def _set_envs_and_config(server_args: ServerArgs):
     )
 
 
-def _wait_and_warmup(server_args, pipe_finish_writer):
+def _wait_and_warmup(server_args, pipe_finish_writer, pid):
     headers = {}
     url = server_args.url()
     if server_args.api_key:
@@ -449,8 +446,9 @@ def _wait_and_warmup(server_args, pipe_finish_writer):
     if not success:
         if pipe_finish_writer is not None:
             pipe_finish_writer.send(last_traceback)
-        print(f"Initialization failed. warmup error: {last_traceback}", flush=True)
-        sys.exit(1)
+        logger.error(f"Initialization failed. warmup error: {last_traceback}")
+        kill_child_process(pid, including_parent=False)
+        return
 
     # Send a warmup request
     request_name = "/generate" if model_info["is_generation"] else "/encode"
@@ -475,12 +473,13 @@ def _wait_and_warmup(server_args, pipe_finish_writer):
             timeout=600,
         )
         assert res.status_code == 200, f"{res}"
-    except Exception as e:
+    except Exception:
         last_traceback = get_exception_traceback()
         if pipe_finish_writer is not None:
             pipe_finish_writer.send(last_traceback)
-        print(f"Initialization failed. warmup error: {last_traceback}", flush=True)
-        sys.exit(1)
+        logger.error(f"Initialization failed. warmup error: {last_traceback}")
+        kill_child_process(pid, including_parent=False)
+        return
 
     logger.info("The server is fired up and ready to roll!")
     if pipe_finish_writer is not None:

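Illustrative use of the new multi-node flow; tp_size, nnodes, and node_rank appear in this diff, while the model path is a placeholder and any distributed init address is omitted. A non-zero-rank node now only launches its local TP servers, joins them, and cleans up its children instead of spinning in a busy loop:

    from sglang.srt.server import launch_server
    from sglang.srt.server_args import ServerArgs

    # Node 1 of a 2-node, tp_size=4 deployment (node 0 would use node_rank=0 and serve HTTP).
    server_args = ServerArgs(model_path="<model>", tp_size=4, nnodes=2, node_rank=1)
    launch_server(server_args)
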
@@ -79,6 +79,7 @@ class ServerArgs:
     disable_radix_cache: bool = False
     disable_regex_jump_forward: bool = False
     disable_cuda_graph: bool = False
+    disable_cuda_graph_padding: bool = False
     disable_disk_cache: bool = False
     enable_mixed_chunk: bool = False
     enable_torch_compile: bool = False
@@ -393,6 +394,11 @@ class ServerArgs:
             action="store_true",
             help="Disable cuda graph.",
         )
+        parser.add_argument(
+            "--disable-cuda-graph-padding",
+            action="store_true",
+            help="Disable cuda graph when padding is needed. Still uses cuda graph when padding is not needed.",
+        )
         parser.add_argument(
             "--disable-disk-cache",
             action="store_true",

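A quick sanity check of the new flag through the normal CLI path (the model path is a placeholder; add_cli_args and from_cli_args are the existing ServerArgs helpers):

    import argparse

    from sglang.srt.server_args import ServerArgs

    parser = argparse.ArgumentParser()
    ServerArgs.add_cli_args(parser)
    args = parser.parse_args(["--model-path", "<model>", "--disable-cuda-graph-padding"])
    assert ServerArgs.from_cli_args(args).disable_cuda_graph_padding is True
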
@@ -369,14 +369,11 @@ def kill_parent_process():
     """Kill the parent process and all children of the parent process."""
     current_process = psutil.Process()
     parent_process = current_process.parent()
-    children = parent_process.children(recursive=True)
-    for child in children:
-        if child.pid != current_process.pid:
-            os.kill(child.pid, 9)
-    os.kill(parent_process.pid, 9)
+    kill_child_process(parent_process.pid, skip_pid=current_process.pid)
 
 
-def kill_child_process(pid, including_parent=True):
+def kill_child_process(pid, including_parent=True, skip_pid=None):
+    """Kill the process and all its children process."""
     try:
         parent = psutil.Process(pid)
     except psutil.NoSuchProcess:
@@ -384,6 +381,8 @@ def kill_child_process(pid, including_parent=True):
 
     children = parent.children(recursive=True)
     for child in children:
+        if child.pid == skip_pid:
+            continue
         try:
             child.kill()
         except psutil.NoSuchProcess:

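Usage sketch of the cleanup helpers after this change, with kill_child_process as defined above: the server's failure paths reap only their children, while kill_parent_process now delegates to kill_child_process with skip_pid so the caller survives its own cleanup.

    import os

    kill_child_process(os.getpid(), including_parent=False)  # kill my children, keep me
    kill_child_process(os.getppid(), skip_pid=os.getpid())    # kill the parent's tree, spare me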