Improve logging & add logit cap (#471)
@@ -6,6 +6,9 @@ class FSMCache(BaseCache):
    def __init__(self, tokenizer_path, tokenizer_args_dict, enable=True):
        super().__init__(enable=enable)

+        if tokenizer_path.endswith(".json"):
+            return

        from importlib.metadata import version

        if version("outlines") >= "0.0.35":
@@ -84,6 +84,9 @@ def get_tokenizer(
    tokenizer_revision: Optional[str] = None,
    **kwargs,
) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
+    if tokenizer_name.endswith(".json"):
+        return TiktokenTokenizer(tokenizer_name)
+
    """Gets a tokenizer for the given model name via Huggingface."""
    if is_multimodal_model(tokenizer_name):
        processor = get_processor(
@@ -170,3 +173,24 @@ def get_processor(
        **kwargs,
    )
    return processor
+
+
+class TiktokenTokenizer:
+    def __init__(self, tokenizer_path):
+        import xlm.tokenizers.tiktoken_wrapper as tiktoken_wrapper
+        tokenizer = tiktoken_wrapper.Encoding.from_xtok_json("xtok-json", tokenizer_path)
+        self.tokenizer = tokenizer
+        self.eos_token_id = tokenizer.eos_token
+        self.vocab_size = tokenizer.n_vocab
+
+    def encode(self, x):
+        return self.tokenizer.encode(x)
+
+    def decode(self, x):
+        return self.tokenizer.decode(x)
+
+    def batch_decode(self, batch, skip_special_tokens, spaces_between_special_tokens):
+        return self.tokenizer.decode_batch(batch)
+
+    def convert_ids_to_tokens(self, index):
+        return self.tokenizer.decode_single_token_bytes(index).decode("utf-8", errors="ignore")
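The wrapper above only has to cover the small tokenizer surface the server actually touches: `encode`, `decode`, `batch_decode`, `convert_ids_to_tokens`, plus `eos_token_id` and `vocab_size`. A minimal usage sketch follows; the file path is a placeholder and assumes an xtok-format JSON tokenizer is available.

```python
# Illustrative only: "my_tokenizer.json" is a placeholder for an xtok-format tokenizer file.
tok = TiktokenTokenizer("my_tokenizer.json")

ids = tok.encode("The capital city of France is")
assert tok.decode(ids) == "The capital city of France is"

print(tok.vocab_size, tok.eos_token_id)
print(tok.convert_ids_to_tokens(ids[0]))  # text of the first token
```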
@@ -8,6 +8,12 @@ from sglang.srt.utils import wrap_kernel_launcher
CUDA_CAPABILITY = torch.cuda.get_device_capability()


+@triton.jit
+def tanh(x):
+    # tanh(x) = 2 * sigmoid(2 * x) - 1 (tanh expressed through tl.sigmoid)
+    return 2 * tl.sigmoid(2 * x) - 1
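The helper exists presumably because the targeted Triton version exposes `tl.sigmoid` but no built-in `tl.tanh`; it relies on the identity tanh(x) = 2·sigmoid(2x) − 1. A quick PyTorch sanity check of that identity (illustrative, not part of the diff):

```python
import torch

x = torch.linspace(-6.0, 6.0, steps=101)
# tanh(x) == 2 * sigmoid(2 * x) - 1 for every x
assert torch.allclose(2 * torch.sigmoid(2 * x) - 1, torch.tanh(x), atol=1e-6)
```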


@triton.jit
def _fwd_kernel(
    Q_Extend,
@@ -39,6 +45,7 @@ def _fwd_kernel(
    BLOCK_DMODEL: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
+    logit_cap: tl.constexpr,
):
    cur_seq = tl.program_id(0)
    cur_head = tl.program_id(1)
@@ -90,6 +97,10 @@ def _fwd_kernel(
        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        qk += tl.dot(q, k)
        qk *= sm_scale

+        if logit_cap > 0:
+            qk = logit_cap * tanh(qk / logit_cap)

        qk = tl.where(mask_m[:, None] & mask_n[None, :], qk, float("-inf"))

        n_e_max = tl.maximum(tl.max(qk, 1), e_max)
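The new behavior is gated on `logit_cap > 0`: the raw q·k scores are squashed with `logit_cap * tanh(qk / logit_cap)`, which keeps every pre-softmax logit inside (−logit_cap, logit_cap) while staying nearly linear for small values (the soft cap used by Grok-style attention, for example). A plain PyTorch sketch of the same transformation, for reference only (not the kernel):

```python
import torch

def soft_cap(scores: torch.Tensor, logit_cap: float) -> torch.Tensor:
    # logit_cap <= 0 disables the cap, matching the kernel's `if logit_cap > 0` guard.
    if logit_cap <= 0:
        return scores
    return logit_cap * torch.tanh(scores / logit_cap)

head_dim = 64
q = torch.randn(4, head_dim)
k = torch.randn(6, head_dim)
scores = (q @ k.T) * head_dim ** -0.5   # sm_scale = 1 / sqrt(head_dim)
capped = soft_cap(scores, logit_cap=30.0)
assert capped.abs().max() < 30.0        # every logit is strictly inside (-30, 30)
```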
@@ -126,6 +137,10 @@ def _fwd_kernel(
        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        qk += tl.dot(q, k)
        qk *= sm_scale

+        if logit_cap > 0:
+            qk = logit_cap * tanh(qk / logit_cap)

        mask_causual = (cur_block_m * BLOCK_M + offs_m[:, None]) >= (
            start_n + offs_n[None, :]
        )
@@ -176,6 +191,7 @@ def extend_attention_fwd(
    b_seq_len_extend,
    max_len_in_batch,
    max_len_extend,
+    logit_cap=-1,
):
    """
    q_extend, k_extend, v_extend, o_extend: contiguous tensors
@@ -271,6 +287,7 @@ def extend_attention_fwd(
        BLOCK_N=BLOCK_N,
        num_warps=num_warps,
        num_stages=num_stages,
+        logit_cap=logit_cap,
    )
    cached_kernel = wrap_kernel_launcher(_fwd_kernel)
@@ -1,4 +1,5 @@
import torch
+import numpy as np
from torch import nn

from sglang.srt.layers.context_flashattention_nopad import context_attention_fwd
@@ -8,13 +9,16 @@ from sglang.srt.managers.router.model_runner import ForwardMode, InputMetadata


class RadixAttention(nn.Module):
-    def __init__(self, num_heads, head_dim, scaling, num_kv_heads, layer_id):
+    def __init__(self, num_heads, head_dim, scaling, num_kv_heads, layer_id, logit_cap=-1):
        super().__init__()
        self.tp_q_head_num = num_heads
        self.tp_k_head_num = num_kv_heads
        self.tp_v_head_num = num_kv_heads
        self.head_dim = head_dim
        self.layer_id = layer_id
+        self.logit_cap = logit_cap

+        assert np.allclose(scaling, 1.0 / (head_dim**0.5))

        from sglang.srt.managers.router.model_runner import global_server_args_dict
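Because the new keyword defaults to -1, existing models are unaffected; a model that wants capped attention opts in by passing a positive value when it builds its attention layers. A hypothetical wiring sketch follows — the import path and the value 30.0 are assumptions for illustration, not taken from this diff.

```python
# Hypothetical model-side wiring; adjust the import path and cap value to your setup.
from sglang.srt.layers.radix_attention import RadixAttention

head_dim = 128
attn = RadixAttention(
    num_heads=32,
    head_dim=head_dim,
    scaling=head_dim ** -0.5,
    num_kv_heads=8,
    layer_id=0,
    logit_cap=30.0,  # any value > 0 enables tanh soft-capping in the Triton kernels
)
```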
@@ -38,6 +42,7 @@ class RadixAttention(nn.Module):
            input_metadata.start_loc,
            input_metadata.seq_lens,
            input_metadata.max_seq_len,
+            self.logit_cap,
        )
        self.store_kv_cache(k, v, input_metadata)

@@ -62,6 +67,7 @@ class RadixAttention(nn.Module):
            input_metadata.extend_seq_lens,
            input_metadata.max_seq_len,
            input_metadata.max_extend_len,
+            self.logit_cap,
        )

        return o
@@ -82,6 +88,7 @@ class RadixAttention(nn.Module):
            input_metadata.max_seq_len,
            input_metadata.other_kv_index,
            input_metadata.total_num_tokens,
+            self.logit_cap,
        )

        return o
@@ -16,6 +16,12 @@ else:
    REDUCE_TORCH_TYPE = torch.float16


+@triton.jit
+def tanh(x):
+    # tanh(x) = 2 * sigmoid(2 * x) - 1 (tanh expressed through tl.sigmoid)
+    return 2 * tl.sigmoid(2 * x) - 1


@triton.jit
def _fwd_kernel_stage1(
    Q,
@@ -35,6 +41,7 @@ def _fwd_kernel_stage1(
    kv_group_num: tl.constexpr,
    BLOCK_DMODEL: tl.constexpr,
    BLOCK_N: tl.constexpr,
+    logit_cap: tl.constexpr,
):
    cur_batch = tl.program_id(0)
    cur_head = tl.program_id(1)
@@ -77,6 +84,10 @@ def _fwd_kernel_stage1(
        ).to(REDUCE_TRITON_TYPE)
        att_value = tl.sum(q[None, :] * k, 1)
        att_value *= sm_scale

+        if logit_cap > 0:
+            att_value = logit_cap * tanh(att_value / logit_cap)

        off_o = cur_head * att_stride_h + (cur_batch_in_all_start_index + offs_n)
        tl.store(Att_Out + off_o, att_value, mask=offs_n_new < cur_batch_end_index)

@@ -165,6 +176,7 @@ def _token_att_m_fwd(
    B_Start_Loc,
    B_Seqlen,
    max_len_in_batch,
+    logit_cap,
):
    BLOCK = 32
    # shape constraints
@@ -223,6 +235,7 @@ def _token_att_m_fwd(
        kv_group_num=kv_group_num,
        BLOCK_DMODEL=Lk,
        BLOCK_N=BLOCK,
+        logit_cap=logit_cap,
        num_warps=num_warps,
        num_stages=1,
    )
@@ -304,6 +317,7 @@ def token_attention_fwd(
    max_len_in_batch,
    other_kv_index,
    total_num_tokens,
+    logit_cap=-1,
    att_m=None,
):
    if att_m is None:
@@ -320,6 +334,7 @@ def token_attention_fwd(
        b_start_loc,
        b_seq_len,
        max_len_in_batch,
+        logit_cap,
    )
    _token_softmax_reducev_fwd(
        att_m,
@@ -1,4 +1,5 @@
import asyncio
+import inspect

import uvloop
import zmq
@@ -7,7 +8,7 @@ import zmq.asyncio
from sglang.srt.hf_transformers_utils import get_tokenizer
from sglang.srt.managers.io_struct import BatchStrOut, BatchTokenIDOut
from sglang.srt.server_args import PortArgs, ServerArgs
-from sglang.utils import get_exception_traceback
+from sglang.utils import get_exception_traceback, graceful_registry

asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())

@@ -85,6 +86,8 @@ def start_detokenizer_process(
    port_args: PortArgs,
    pipe_writer,
):
+    graceful_registry(inspect.currentframe().f_code.co_name)

    try:
        manager = DetokenizerManager(server_args, port_args)
    except Exception as e:
@@ -106,8 +106,7 @@ class ModelRpcServer:
        set_random_seed(server_args.random_seed)

        # Print info
-        logger.info(
-            f"Rank {self.tp_rank}: "
+        logger.info(f"[rank={self.tp_rank}] "
            f"max_total_num_token={self.max_total_num_token}, "
            f"max_prefill_num_token={self.max_prefill_num_token}, "
            f"context_len={self.model_config.context_len}, "
@@ -752,7 +751,7 @@ def _init_service(port):
        protocol_config={
            "allow_public_attrs": True,
            "allow_pickle": True,
-            "sync_request_timeout": 1800,
+            "sync_request_timeout": 3600,
        },
    )
    t.start()
@@ -772,7 +771,7 @@ def start_model_process(port):
                config={
                    "allow_public_attrs": True,
                    "allow_pickle": True,
-                    "sync_request_timeout": 1800,
+                    "sync_request_timeout": 3600,
                },
            )
            break
@@ -235,8 +235,8 @@ class ModelRunner:
        }

        # Init torch distributed
-        logger.debug("Init torch begin.")
        torch.cuda.set_device(self.tp_rank)
+        logger.info(f"[rank={self.tp_rank}] Init torch begin. Avail mem={get_available_gpu_memory(self.tp_rank):.2f} GB")
        torch.distributed.init_process_group(
            backend="nccl",
            world_size=self.tp_size,
@@ -244,20 +244,22 @@ class ModelRunner:
            init_method=f"tcp://127.0.0.1:{self.nccl_port}",
        )
        initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
-        logger.debug("Init torch end.")
+        logger.info(f"[rank={self.tp_rank}] Init torch end.")

+        total_gpu_memory = get_available_gpu_memory(self.tp_rank, distributed=self.tp_size > 1)

+        if self.tp_size > 1:
+            total_local_gpu_memory = get_available_gpu_memory(self.tp_rank)
+            if total_local_gpu_memory < total_gpu_memory * 0.9:
+                raise ValueError("The memory capacity is unbalanced. Some GPUs may be occupied by other processes.")

-        total_gpu_memory = get_available_gpu_memory(
-            self.tp_rank, distributed=self.tp_size > 1
-        ) * (1 << 30)
        # logger.info(f"Before: {get_available_gpu_memory(self.tp_rank, False):.2f} GB")
        self.load_model()
        # logger.info(f"After: {get_available_gpu_memory(self.tp_rank, False):.2f} GB")
        self.init_memory_pool(total_gpu_memory)

        self.is_multimodal_model = is_multimodal_model(self.model_config)

    def load_model(self):
-        logger.info(f"Rank {self.tp_rank}: load weight begin.")
+        logger.info(f"[rank={self.tp_rank}] Load weight begin.")

        device_config = DeviceConfig()
        load_config = LoadConfig(load_format=self.server_args.load_format)
@@ -283,19 +285,19 @@ class ModelRunner:
            parallel_config=None,
            scheduler_config=None,
        )
-        logger.info(f"Rank {self.tp_rank}: load weight end. {type(self.model)}")
+        logger.info(f"[rank={self.tp_rank}] Load weight end. "
+            f"Type={type(self.model).__name__}. "
+            f"Avail mem={get_available_gpu_memory(self.tp_rank):.2f} GB")

    def profile_max_num_token(self, total_gpu_memory):
-        available_gpu_memory = get_available_gpu_memory(
-            self.tp_rank, distributed=self.tp_size > 1
-        ) * (1 << 30)
+        available_gpu_memory = get_available_gpu_memory(self.tp_rank, distributed=self.tp_size > 1)
        head_dim = self.model_config.head_dim
        head_num = self.model_config.num_key_value_heads // self.tp_size
        cell_size = head_num * head_dim * self.model_config.num_hidden_layers * 2 * 2
        rest_memory = available_gpu_memory - total_gpu_memory * (
            1 - self.mem_fraction_static
        )
-        max_num_token = int(rest_memory // cell_size)
+        max_num_token = int(rest_memory * (1 << 30) // cell_size)
        return max_num_token

    def init_memory_pool(self, total_gpu_memory):
@@ -203,7 +203,6 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
        time.sleep(0.5)
        try:
            requests.get(url + "/get_model_info", timeout=5, headers=headers)
            success = True  # Set flag to True if request succeeds
            break
        except requests.exceptions.RequestException as e:
            pass
@@ -213,7 +212,7 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
    res = requests.post(
        url + "/generate",
        json={
-            "text": "Say this is a warmup request.",
+            "text": "The capital city of France is",
            "sampling_params": {
                "temperature": 0,
                "max_new_tokens": 16,
@@ -92,7 +92,7 @@ def calculate_time(show=False, min_cost_ms=0.0):
        return wrapper


-def get_available_gpu_memory(gpu_id, distributed=True):
+def get_available_gpu_memory(gpu_id, distributed=False):
    """
    Get available memory for cuda:gpu_id device.
    When distributed is True, the available memory is the minimum available memory of all GPUs.
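Flipping the default to `distributed=False` means plain call sites now report this GPU's free memory (in GB) without touching any collective; multi-GPU callers such as `ModelRunner` above pass `distributed=` explicitly. A small illustrative call — the import path is assumed:

```python
from sglang.srt.utils import get_available_gpu_memory

free_gb = get_available_gpu_memory(gpu_id=0)  # this device only, in GB
# Requires torch.distributed to be initialized; returns the minimum across all ranks.
min_free_gb = get_available_gpu_memory(gpu_id=0, distributed=True)
```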
@@ -2,7 +2,8 @@

import base64
import json
import os
import logging
import signal
import sys
import threading
import traceback
@@ -15,6 +16,9 @@ import numpy as np
import requests


+logger = logging.getLogger(__name__)


def get_exception_traceback():
    etype, value, tb = sys.exc_info()
    err_str = "".join(traceback.format_exception(etype, value, tb))
@@ -247,3 +251,12 @@ def run_with_timeout(func, args=(), kwargs=None, timeout=None):
        raise RuntimeError()

    return ret_value[0]


+def graceful_registry(sub_module_name):
+    def graceful_shutdown(signum, frame):
+        logger.info(f"{sub_module_name} received a signal to shut down; performing graceful shutdown...")
+        if signum == signal.SIGTERM:
+            logger.info(f"{sub_module_name} received SIGTERM")
+
+    signal.signal(signal.SIGTERM, graceful_shutdown)
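The detokenizer change earlier in this diff shows the intended call pattern: each spawned process registers the handler under its own entry-point name. A minimal sketch of the same pattern (names are illustrative):

```python
import inspect

def start_worker_process():
    # Log SIGTERM under this entry point's name, mirroring start_detokenizer_process().
    graceful_registry(inspect.currentframe().f_code.co_name)
    # ... run the worker loop ...
```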