Improve logging & add logit cap (#471)
This commit is contained in:
@@ -30,7 +30,7 @@ if __name__ == "__main__":
|
|||||||
response = requests.post(
|
response = requests.post(
|
||||||
url + "/generate",
|
url + "/generate",
|
||||||
json={
|
json={
|
||||||
"text": f"{a}, ",
|
"text": f"The capital of France is",
|
||||||
# "input_ids": [[2] * 256] * 196,
|
# "input_ids": [[2] * 256] * 196,
|
||||||
"sampling_params": {
|
"sampling_params": {
|
||||||
"temperature": 0,
|
"temperature": 0,
|
||||||
|
|||||||
@@ -6,6 +6,9 @@ class FSMCache(BaseCache):
|
|||||||
def __init__(self, tokenizer_path, tokenizer_args_dict, enable=True):
|
def __init__(self, tokenizer_path, tokenizer_args_dict, enable=True):
|
||||||
super().__init__(enable=enable)
|
super().__init__(enable=enable)
|
||||||
|
|
||||||
|
if tokenizer_path.endswith(".json"):
|
||||||
|
return
|
||||||
|
|
||||||
from importlib.metadata import version
|
from importlib.metadata import version
|
||||||
|
|
||||||
if version("outlines") >= "0.0.35":
|
if version("outlines") >= "0.0.35":
|
||||||
|
|||||||
@@ -84,6 +84,9 @@ def get_tokenizer(
|
|||||||
tokenizer_revision: Optional[str] = None,
|
tokenizer_revision: Optional[str] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
|
) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
|
||||||
|
if tokenizer_name.endswith(".json"):
|
||||||
|
return TiktokenTokenizer(tokenizer_name)
|
||||||
|
|
||||||
"""Gets a tokenizer for the given model name via Huggingface."""
|
"""Gets a tokenizer for the given model name via Huggingface."""
|
||||||
if is_multimodal_model(tokenizer_name):
|
if is_multimodal_model(tokenizer_name):
|
||||||
processor = get_processor(
|
processor = get_processor(
|
||||||
@@ -170,3 +173,24 @@ def get_processor(
|
|||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
return processor
|
return processor
|
||||||
|
|
||||||
|
|
||||||
|
class TiktokenTokenizer:
|
||||||
|
def __init__(self, tokenizer_path):
|
||||||
|
import xlm.tokenizers.tiktoken_wrapper as tiktoken_wrapper
|
||||||
|
tokenizer = tiktoken_wrapper.Encoding.from_xtok_json("xtok-json", tokenizer_path)
|
||||||
|
self.tokenizer = tokenizer
|
||||||
|
self.eos_token_id = tokenizer.eos_token
|
||||||
|
self.vocab_size = tokenizer.n_vocab
|
||||||
|
|
||||||
|
def encode(self, x):
|
||||||
|
return self.tokenizer.encode(x)
|
||||||
|
|
||||||
|
def decode(self, x):
|
||||||
|
return self.tokenizer.decode(x)
|
||||||
|
|
||||||
|
def batch_decode(self, batch, skip_special_tokens, spaces_between_special_tokens):
|
||||||
|
return self.tokenizer.decode_batch(batch)
|
||||||
|
|
||||||
|
def convert_ids_to_tokens(self, index):
|
||||||
|
return self.tokenizer.decode_single_token_bytes(index).decode("utf-8", errors="ignore")
|
||||||
@@ -8,6 +8,12 @@ from sglang.srt.utils import wrap_kernel_launcher
|
|||||||
CUDA_CAPABILITY = torch.cuda.get_device_capability()
|
CUDA_CAPABILITY = torch.cuda.get_device_capability()
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
|
||||||
|
def tanh(x):
|
||||||
|
# Tanh is just a scaled sigmoid
|
||||||
|
return 2 * tl.sigmoid(2 * x) - 1
|
||||||
|
|
||||||
|
|
||||||
@triton.jit
|
@triton.jit
|
||||||
def _fwd_kernel(
|
def _fwd_kernel(
|
||||||
Q_Extend,
|
Q_Extend,
|
||||||
@@ -39,6 +45,7 @@ def _fwd_kernel(
|
|||||||
BLOCK_DMODEL: tl.constexpr,
|
BLOCK_DMODEL: tl.constexpr,
|
||||||
BLOCK_M: tl.constexpr,
|
BLOCK_M: tl.constexpr,
|
||||||
BLOCK_N: tl.constexpr,
|
BLOCK_N: tl.constexpr,
|
||||||
|
logit_cap: tl.constexpr,
|
||||||
):
|
):
|
||||||
cur_seq = tl.program_id(0)
|
cur_seq = tl.program_id(0)
|
||||||
cur_head = tl.program_id(1)
|
cur_head = tl.program_id(1)
|
||||||
@@ -90,6 +97,10 @@ def _fwd_kernel(
|
|||||||
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
|
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
|
||||||
qk += tl.dot(q, k)
|
qk += tl.dot(q, k)
|
||||||
qk *= sm_scale
|
qk *= sm_scale
|
||||||
|
|
||||||
|
if logit_cap > 0:
|
||||||
|
qk = logit_cap * tanh(qk / logit_cap)
|
||||||
|
|
||||||
qk = tl.where(mask_m[:, None] & mask_n[None, :], qk, float("-inf"))
|
qk = tl.where(mask_m[:, None] & mask_n[None, :], qk, float("-inf"))
|
||||||
|
|
||||||
n_e_max = tl.maximum(tl.max(qk, 1), e_max)
|
n_e_max = tl.maximum(tl.max(qk, 1), e_max)
|
||||||
@@ -126,6 +137,10 @@ def _fwd_kernel(
|
|||||||
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
|
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
|
||||||
qk += tl.dot(q, k)
|
qk += tl.dot(q, k)
|
||||||
qk *= sm_scale
|
qk *= sm_scale
|
||||||
|
|
||||||
|
if logit_cap > 0:
|
||||||
|
qk = logit_cap * tanh(qk / logit_cap)
|
||||||
|
|
||||||
mask_causual = (cur_block_m * BLOCK_M + offs_m[:, None]) >= (
|
mask_causual = (cur_block_m * BLOCK_M + offs_m[:, None]) >= (
|
||||||
start_n + offs_n[None, :]
|
start_n + offs_n[None, :]
|
||||||
)
|
)
|
||||||
@@ -176,6 +191,7 @@ def extend_attention_fwd(
|
|||||||
b_seq_len_extend,
|
b_seq_len_extend,
|
||||||
max_len_in_batch,
|
max_len_in_batch,
|
||||||
max_len_extend,
|
max_len_extend,
|
||||||
|
logit_cap=-1,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
q_extend, k_extend, v_extend, o_extend: contiguous tensors
|
q_extend, k_extend, v_extend, o_extend: contiguous tensors
|
||||||
@@ -271,6 +287,7 @@ def extend_attention_fwd(
|
|||||||
BLOCK_N=BLOCK_N,
|
BLOCK_N=BLOCK_N,
|
||||||
num_warps=num_warps,
|
num_warps=num_warps,
|
||||||
num_stages=num_stages,
|
num_stages=num_stages,
|
||||||
|
logit_cap=logit_cap,
|
||||||
)
|
)
|
||||||
cached_kernel = wrap_kernel_launcher(_fwd_kernel)
|
cached_kernel = wrap_kernel_launcher(_fwd_kernel)
|
||||||
|
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
import torch
|
import torch
|
||||||
|
import numpy as np
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
from sglang.srt.layers.context_flashattention_nopad import context_attention_fwd
|
from sglang.srt.layers.context_flashattention_nopad import context_attention_fwd
|
||||||
@@ -8,13 +9,16 @@ from sglang.srt.managers.router.model_runner import ForwardMode, InputMetadata
|
|||||||
|
|
||||||
|
|
||||||
class RadixAttention(nn.Module):
|
class RadixAttention(nn.Module):
|
||||||
def __init__(self, num_heads, head_dim, scaling, num_kv_heads, layer_id):
|
def __init__(self, num_heads, head_dim, scaling, num_kv_heads, layer_id, logit_cap=-1):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.tp_q_head_num = num_heads
|
self.tp_q_head_num = num_heads
|
||||||
self.tp_k_head_num = num_kv_heads
|
self.tp_k_head_num = num_kv_heads
|
||||||
self.tp_v_head_num = num_kv_heads
|
self.tp_v_head_num = num_kv_heads
|
||||||
self.head_dim = head_dim
|
self.head_dim = head_dim
|
||||||
self.layer_id = layer_id
|
self.layer_id = layer_id
|
||||||
|
self.logit_cap = logit_cap
|
||||||
|
|
||||||
|
assert np.allclose(scaling, 1.0 / (head_dim**0.5))
|
||||||
|
|
||||||
from sglang.srt.managers.router.model_runner import global_server_args_dict
|
from sglang.srt.managers.router.model_runner import global_server_args_dict
|
||||||
|
|
||||||
@@ -38,6 +42,7 @@ class RadixAttention(nn.Module):
|
|||||||
input_metadata.start_loc,
|
input_metadata.start_loc,
|
||||||
input_metadata.seq_lens,
|
input_metadata.seq_lens,
|
||||||
input_metadata.max_seq_len,
|
input_metadata.max_seq_len,
|
||||||
|
self.logit_cap,
|
||||||
)
|
)
|
||||||
self.store_kv_cache(k, v, input_metadata)
|
self.store_kv_cache(k, v, input_metadata)
|
||||||
|
|
||||||
@@ -62,6 +67,7 @@ class RadixAttention(nn.Module):
|
|||||||
input_metadata.extend_seq_lens,
|
input_metadata.extend_seq_lens,
|
||||||
input_metadata.max_seq_len,
|
input_metadata.max_seq_len,
|
||||||
input_metadata.max_extend_len,
|
input_metadata.max_extend_len,
|
||||||
|
self.logit_cap,
|
||||||
)
|
)
|
||||||
|
|
||||||
return o
|
return o
|
||||||
@@ -82,6 +88,7 @@ class RadixAttention(nn.Module):
|
|||||||
input_metadata.max_seq_len,
|
input_metadata.max_seq_len,
|
||||||
input_metadata.other_kv_index,
|
input_metadata.other_kv_index,
|
||||||
input_metadata.total_num_tokens,
|
input_metadata.total_num_tokens,
|
||||||
|
self.logit_cap,
|
||||||
)
|
)
|
||||||
|
|
||||||
return o
|
return o
|
||||||
|
|||||||
@@ -16,6 +16,12 @@ else:
|
|||||||
REDUCE_TORCH_TYPE = torch.float16
|
REDUCE_TORCH_TYPE = torch.float16
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
|
||||||
|
def tanh(x):
|
||||||
|
# Tanh is just a scaled sigmoid
|
||||||
|
return 2 * tl.sigmoid(2 * x) - 1
|
||||||
|
|
||||||
|
|
||||||
@triton.jit
|
@triton.jit
|
||||||
def _fwd_kernel_stage1(
|
def _fwd_kernel_stage1(
|
||||||
Q,
|
Q,
|
||||||
@@ -35,6 +41,7 @@ def _fwd_kernel_stage1(
|
|||||||
kv_group_num: tl.constexpr,
|
kv_group_num: tl.constexpr,
|
||||||
BLOCK_DMODEL: tl.constexpr,
|
BLOCK_DMODEL: tl.constexpr,
|
||||||
BLOCK_N: tl.constexpr,
|
BLOCK_N: tl.constexpr,
|
||||||
|
logit_cap: tl.constexpr,
|
||||||
):
|
):
|
||||||
cur_batch = tl.program_id(0)
|
cur_batch = tl.program_id(0)
|
||||||
cur_head = tl.program_id(1)
|
cur_head = tl.program_id(1)
|
||||||
@@ -77,6 +84,10 @@ def _fwd_kernel_stage1(
|
|||||||
).to(REDUCE_TRITON_TYPE)
|
).to(REDUCE_TRITON_TYPE)
|
||||||
att_value = tl.sum(q[None, :] * k, 1)
|
att_value = tl.sum(q[None, :] * k, 1)
|
||||||
att_value *= sm_scale
|
att_value *= sm_scale
|
||||||
|
|
||||||
|
if logit_cap > 0:
|
||||||
|
att_value = logit_cap * tanh(att_value / logit_cap)
|
||||||
|
|
||||||
off_o = cur_head * att_stride_h + (cur_batch_in_all_start_index + offs_n)
|
off_o = cur_head * att_stride_h + (cur_batch_in_all_start_index + offs_n)
|
||||||
tl.store(Att_Out + off_o, att_value, mask=offs_n_new < cur_batch_end_index)
|
tl.store(Att_Out + off_o, att_value, mask=offs_n_new < cur_batch_end_index)
|
||||||
|
|
||||||
@@ -165,6 +176,7 @@ def _token_att_m_fwd(
|
|||||||
B_Start_Loc,
|
B_Start_Loc,
|
||||||
B_Seqlen,
|
B_Seqlen,
|
||||||
max_len_in_batch,
|
max_len_in_batch,
|
||||||
|
logit_cap,
|
||||||
):
|
):
|
||||||
BLOCK = 32
|
BLOCK = 32
|
||||||
# shape constraints
|
# shape constraints
|
||||||
@@ -223,6 +235,7 @@ def _token_att_m_fwd(
|
|||||||
kv_group_num=kv_group_num,
|
kv_group_num=kv_group_num,
|
||||||
BLOCK_DMODEL=Lk,
|
BLOCK_DMODEL=Lk,
|
||||||
BLOCK_N=BLOCK,
|
BLOCK_N=BLOCK,
|
||||||
|
logit_cap=logit_cap,
|
||||||
num_warps=num_warps,
|
num_warps=num_warps,
|
||||||
num_stages=1,
|
num_stages=1,
|
||||||
)
|
)
|
||||||
@@ -304,6 +317,7 @@ def token_attention_fwd(
|
|||||||
max_len_in_batch,
|
max_len_in_batch,
|
||||||
other_kv_index,
|
other_kv_index,
|
||||||
total_num_tokens,
|
total_num_tokens,
|
||||||
|
logit_cap=-1,
|
||||||
att_m=None,
|
att_m=None,
|
||||||
):
|
):
|
||||||
if att_m is None:
|
if att_m is None:
|
||||||
@@ -320,6 +334,7 @@ def token_attention_fwd(
|
|||||||
b_start_loc,
|
b_start_loc,
|
||||||
b_seq_len,
|
b_seq_len,
|
||||||
max_len_in_batch,
|
max_len_in_batch,
|
||||||
|
logit_cap,
|
||||||
)
|
)
|
||||||
_token_softmax_reducev_fwd(
|
_token_softmax_reducev_fwd(
|
||||||
att_m,
|
att_m,
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
|
import inspect
|
||||||
|
|
||||||
import uvloop
|
import uvloop
|
||||||
import zmq
|
import zmq
|
||||||
@@ -7,7 +8,7 @@ import zmq.asyncio
|
|||||||
from sglang.srt.hf_transformers_utils import get_tokenizer
|
from sglang.srt.hf_transformers_utils import get_tokenizer
|
||||||
from sglang.srt.managers.io_struct import BatchStrOut, BatchTokenIDOut
|
from sglang.srt.managers.io_struct import BatchStrOut, BatchTokenIDOut
|
||||||
from sglang.srt.server_args import PortArgs, ServerArgs
|
from sglang.srt.server_args import PortArgs, ServerArgs
|
||||||
from sglang.utils import get_exception_traceback
|
from sglang.utils import get_exception_traceback, graceful_registry
|
||||||
|
|
||||||
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
|
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
|
||||||
|
|
||||||
@@ -85,6 +86,8 @@ def start_detokenizer_process(
|
|||||||
port_args: PortArgs,
|
port_args: PortArgs,
|
||||||
pipe_writer,
|
pipe_writer,
|
||||||
):
|
):
|
||||||
|
graceful_registry(inspect.currentframe().f_code.co_name)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
manager = DetokenizerManager(server_args, port_args)
|
manager = DetokenizerManager(server_args, port_args)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
@@ -106,8 +106,7 @@ class ModelRpcServer:
|
|||||||
set_random_seed(server_args.random_seed)
|
set_random_seed(server_args.random_seed)
|
||||||
|
|
||||||
# Print info
|
# Print info
|
||||||
logger.info(
|
logger.info(f"[rank={self.tp_rank}] "
|
||||||
f"Rank {self.tp_rank}: "
|
|
||||||
f"max_total_num_token={self.max_total_num_token}, "
|
f"max_total_num_token={self.max_total_num_token}, "
|
||||||
f"max_prefill_num_token={self.max_prefill_num_token}, "
|
f"max_prefill_num_token={self.max_prefill_num_token}, "
|
||||||
f"context_len={self.model_config.context_len}, "
|
f"context_len={self.model_config.context_len}, "
|
||||||
@@ -752,7 +751,7 @@ def _init_service(port):
|
|||||||
protocol_config={
|
protocol_config={
|
||||||
"allow_public_attrs": True,
|
"allow_public_attrs": True,
|
||||||
"allow_pickle": True,
|
"allow_pickle": True,
|
||||||
"sync_request_timeout": 1800,
|
"sync_request_timeout": 3600,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
t.start()
|
t.start()
|
||||||
@@ -772,7 +771,7 @@ def start_model_process(port):
|
|||||||
config={
|
config={
|
||||||
"allow_public_attrs": True,
|
"allow_public_attrs": True,
|
||||||
"allow_pickle": True,
|
"allow_pickle": True,
|
||||||
"sync_request_timeout": 1800,
|
"sync_request_timeout": 3600,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
break
|
break
|
||||||
|
|||||||
@@ -235,8 +235,8 @@ class ModelRunner:
|
|||||||
}
|
}
|
||||||
|
|
||||||
# Init torch distributed
|
# Init torch distributed
|
||||||
logger.debug("Init torch begin.")
|
|
||||||
torch.cuda.set_device(self.tp_rank)
|
torch.cuda.set_device(self.tp_rank)
|
||||||
|
logger.info(f"[rank={self.tp_rank}] Init torch begin. Avail mem={get_available_gpu_memory(self.tp_rank):.2f} GB")
|
||||||
torch.distributed.init_process_group(
|
torch.distributed.init_process_group(
|
||||||
backend="nccl",
|
backend="nccl",
|
||||||
world_size=self.tp_size,
|
world_size=self.tp_size,
|
||||||
@@ -244,20 +244,22 @@ class ModelRunner:
|
|||||||
init_method=f"tcp://127.0.0.1:{self.nccl_port}",
|
init_method=f"tcp://127.0.0.1:{self.nccl_port}",
|
||||||
)
|
)
|
||||||
initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
|
initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
|
||||||
logger.debug("Init torch end.")
|
logger.info(f"[rank={self.tp_rank}] Init torch end.")
|
||||||
|
|
||||||
|
total_gpu_memory = get_available_gpu_memory(self.tp_rank, distributed=self.tp_size > 1)
|
||||||
|
|
||||||
|
if self.tp_size > 1:
|
||||||
|
total_local_gpu_memory = get_available_gpu_memory(self.tp_rank)
|
||||||
|
if total_local_gpu_memory < total_gpu_memory * 0.9:
|
||||||
|
raise ValueError("The memory capacity is unbalanced. Some GPUs may be occupied by other processes.")
|
||||||
|
|
||||||
total_gpu_memory = get_available_gpu_memory(
|
|
||||||
self.tp_rank, distributed=self.tp_size > 1
|
|
||||||
) * (1 << 30)
|
|
||||||
# logger.info(f"Before: {get_available_gpu_memory(self.tp_rank, False):.2f} GB")
|
|
||||||
self.load_model()
|
self.load_model()
|
||||||
# logger.info(f"After: {get_available_gpu_memory(self.tp_rank, False):.2f} GB")
|
|
||||||
self.init_memory_pool(total_gpu_memory)
|
self.init_memory_pool(total_gpu_memory)
|
||||||
|
|
||||||
self.is_multimodal_model = is_multimodal_model(self.model_config)
|
self.is_multimodal_model = is_multimodal_model(self.model_config)
|
||||||
|
|
||||||
def load_model(self):
|
def load_model(self):
|
||||||
logger.info(f"Rank {self.tp_rank}: load weight begin.")
|
logger.info(f"[rank={self.tp_rank}] Load weight begin.")
|
||||||
|
|
||||||
device_config = DeviceConfig()
|
device_config = DeviceConfig()
|
||||||
load_config = LoadConfig(load_format=self.server_args.load_format)
|
load_config = LoadConfig(load_format=self.server_args.load_format)
|
||||||
@@ -283,19 +285,19 @@ class ModelRunner:
|
|||||||
parallel_config=None,
|
parallel_config=None,
|
||||||
scheduler_config=None,
|
scheduler_config=None,
|
||||||
)
|
)
|
||||||
logger.info(f"Rank {self.tp_rank}: load weight end. {type(self.model)}")
|
logger.info(f"[rank={self.tp_rank}] Load weight end. "
|
||||||
|
f"Type={type(self.model).__name__}. "
|
||||||
|
f"Avail mem={get_available_gpu_memory(self.tp_rank):.2f} GB")
|
||||||
|
|
||||||
def profile_max_num_token(self, total_gpu_memory):
|
def profile_max_num_token(self, total_gpu_memory):
|
||||||
available_gpu_memory = get_available_gpu_memory(
|
available_gpu_memory = get_available_gpu_memory(self.tp_rank, distributed=self.tp_size > 1)
|
||||||
self.tp_rank, distributed=self.tp_size > 1
|
|
||||||
) * (1 << 30)
|
|
||||||
head_dim = self.model_config.head_dim
|
head_dim = self.model_config.head_dim
|
||||||
head_num = self.model_config.num_key_value_heads // self.tp_size
|
head_num = self.model_config.num_key_value_heads // self.tp_size
|
||||||
cell_size = head_num * head_dim * self.model_config.num_hidden_layers * 2 * 2
|
cell_size = head_num * head_dim * self.model_config.num_hidden_layers * 2 * 2
|
||||||
rest_memory = available_gpu_memory - total_gpu_memory * (
|
rest_memory = available_gpu_memory - total_gpu_memory * (
|
||||||
1 - self.mem_fraction_static
|
1 - self.mem_fraction_static
|
||||||
)
|
)
|
||||||
max_num_token = int(rest_memory // cell_size)
|
max_num_token = int(rest_memory * (1 << 30) // cell_size)
|
||||||
return max_num_token
|
return max_num_token
|
||||||
|
|
||||||
def init_memory_pool(self, total_gpu_memory):
|
def init_memory_pool(self, total_gpu_memory):
|
||||||
|
|||||||
@@ -203,7 +203,6 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
|
|||||||
time.sleep(0.5)
|
time.sleep(0.5)
|
||||||
try:
|
try:
|
||||||
requests.get(url + "/get_model_info", timeout=5, headers=headers)
|
requests.get(url + "/get_model_info", timeout=5, headers=headers)
|
||||||
success = True # Set flag to True if request succeeds
|
|
||||||
break
|
break
|
||||||
except requests.exceptions.RequestException as e:
|
except requests.exceptions.RequestException as e:
|
||||||
pass
|
pass
|
||||||
@@ -213,7 +212,7 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
|
|||||||
res = requests.post(
|
res = requests.post(
|
||||||
url + "/generate",
|
url + "/generate",
|
||||||
json={
|
json={
|
||||||
"text": "Say this is a warmup request.",
|
"text": "The capital city of France is",
|
||||||
"sampling_params": {
|
"sampling_params": {
|
||||||
"temperature": 0,
|
"temperature": 0,
|
||||||
"max_new_tokens": 16,
|
"max_new_tokens": 16,
|
||||||
|
|||||||
@@ -92,7 +92,7 @@ def calculate_time(show=False, min_cost_ms=0.0):
|
|||||||
return wrapper
|
return wrapper
|
||||||
|
|
||||||
|
|
||||||
def get_available_gpu_memory(gpu_id, distributed=True):
|
def get_available_gpu_memory(gpu_id, distributed=False):
|
||||||
"""
|
"""
|
||||||
Get available memory for cuda:gpu_id device.
|
Get available memory for cuda:gpu_id device.
|
||||||
When distributed is True, the available memory is the minimum available memory of all GPUs.
|
When distributed is True, the available memory is the minimum available memory of all GPUs.
|
||||||
|
|||||||
@@ -2,7 +2,8 @@
|
|||||||
|
|
||||||
import base64
|
import base64
|
||||||
import json
|
import json
|
||||||
import os
|
import logging
|
||||||
|
import signal
|
||||||
import sys
|
import sys
|
||||||
import threading
|
import threading
|
||||||
import traceback
|
import traceback
|
||||||
@@ -15,6 +16,9 @@ import numpy as np
|
|||||||
import requests
|
import requests
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def get_exception_traceback():
|
def get_exception_traceback():
|
||||||
etype, value, tb = sys.exc_info()
|
etype, value, tb = sys.exc_info()
|
||||||
err_str = "".join(traceback.format_exception(etype, value, tb))
|
err_str = "".join(traceback.format_exception(etype, value, tb))
|
||||||
@@ -247,3 +251,12 @@ def run_with_timeout(func, args=(), kwargs=None, timeout=None):
|
|||||||
raise RuntimeError()
|
raise RuntimeError()
|
||||||
|
|
||||||
return ret_value[0]
|
return ret_value[0]
|
||||||
|
|
||||||
|
|
||||||
|
def graceful_registry(sub_module_name):
|
||||||
|
def graceful_shutdown(signum, frame):
|
||||||
|
logger.info(f"{sub_module_name} Received signal to shutdown. Performing graceful shutdown...")
|
||||||
|
if signum == signal.SIGTERM:
|
||||||
|
logger.info(f"{sub_module_name} recive sigterm")
|
||||||
|
|
||||||
|
signal.signal(signal.SIGTERM, graceful_shutdown)
|
||||||
Reference in New Issue
Block a user