From 2cea6146d8735780da602c0dfa0569b0fb5d47ba Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Fri, 24 May 2024 03:48:53 -0700 Subject: [PATCH] Improve logging & add logit cap (#471) --- benchmark/latency_throughput/test_latency.py | 2 +- python/sglang/srt/constrained/fsm_cache.py | 3 ++ python/sglang/srt/hf_transformers_utils.py | 24 ++++++++++++++++ python/sglang/srt/layers/extend_attention.py | 17 +++++++++++ python/sglang/srt/layers/radix_attention.py | 9 +++++- python/sglang/srt/layers/token_attention.py | 15 ++++++++++ .../srt/managers/detokenizer_manager.py | 5 +++- .../sglang/srt/managers/router/model_rpc.py | 7 ++--- .../srt/managers/router/model_runner.py | 28 ++++++++++--------- python/sglang/srt/server.py | 3 +- python/sglang/srt/utils.py | 2 +- python/sglang/utils.py | 15 +++++++++- 12 files changed, 106 insertions(+), 24 deletions(-) diff --git a/benchmark/latency_throughput/test_latency.py b/benchmark/latency_throughput/test_latency.py index 140b959ec..f65f390e9 100644 --- a/benchmark/latency_throughput/test_latency.py +++ b/benchmark/latency_throughput/test_latency.py @@ -30,7 +30,7 @@ if __name__ == "__main__": response = requests.post( url + "/generate", json={ - "text": f"{a}, ", + "text": f"The capital of France is", # "input_ids": [[2] * 256] * 196, "sampling_params": { "temperature": 0, diff --git a/python/sglang/srt/constrained/fsm_cache.py b/python/sglang/srt/constrained/fsm_cache.py index 9c467ab33..fb1588f95 100644 --- a/python/sglang/srt/constrained/fsm_cache.py +++ b/python/sglang/srt/constrained/fsm_cache.py @@ -6,6 +6,9 @@ class FSMCache(BaseCache): def __init__(self, tokenizer_path, tokenizer_args_dict, enable=True): super().__init__(enable=enable) + if tokenizer_path.endswith(".json"): + return + from importlib.metadata import version if version("outlines") >= "0.0.35": diff --git a/python/sglang/srt/hf_transformers_utils.py b/python/sglang/srt/hf_transformers_utils.py index 9d2f917d8..b34168462 100644 --- a/python/sglang/srt/hf_transformers_utils.py +++ b/python/sglang/srt/hf_transformers_utils.py @@ -84,6 +84,9 @@ def get_tokenizer( tokenizer_revision: Optional[str] = None, **kwargs, ) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]: + if tokenizer_name.endswith(".json"): + return TiktokenTokenizer(tokenizer_name) + """Gets a tokenizer for the given model name via Huggingface.""" if is_multimodal_model(tokenizer_name): processor = get_processor( @@ -170,3 +173,24 @@ def get_processor( **kwargs, ) return processor + + +class TiktokenTokenizer: + def __init__(self, tokenizer_path): + import xlm.tokenizers.tiktoken_wrapper as tiktoken_wrapper + tokenizer = tiktoken_wrapper.Encoding.from_xtok_json("xtok-json", tokenizer_path) + self.tokenizer = tokenizer + self.eos_token_id = tokenizer.eos_token + self.vocab_size = tokenizer.n_vocab + + def encode(self, x): + return self.tokenizer.encode(x) + + def decode(self, x): + return self.tokenizer.decode(x) + + def batch_decode(self, batch, skip_special_tokens, spaces_between_special_tokens): + return self.tokenizer.decode_batch(batch) + + def convert_ids_to_tokens(self, index): + return self.tokenizer.decode_single_token_bytes(index).decode("utf-8", errors="ignore") \ No newline at end of file diff --git a/python/sglang/srt/layers/extend_attention.py b/python/sglang/srt/layers/extend_attention.py index ce402a910..5b1dba40f 100644 --- a/python/sglang/srt/layers/extend_attention.py +++ b/python/sglang/srt/layers/extend_attention.py @@ -8,6 +8,12 @@ from sglang.srt.utils import wrap_kernel_launcher 
CUDA_CAPABILITY = torch.cuda.get_device_capability() +@triton.jit +def tanh(x): + # Tanh is just a scaled sigmoid + return 2 * tl.sigmoid(2 * x) - 1 + + @triton.jit def _fwd_kernel( Q_Extend, @@ -39,6 +45,7 @@ def _fwd_kernel( BLOCK_DMODEL: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, + logit_cap: tl.constexpr, ): cur_seq = tl.program_id(0) cur_head = tl.program_id(1) @@ -90,6 +97,10 @@ def _fwd_kernel( qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) qk += tl.dot(q, k) qk *= sm_scale + + if logit_cap > 0: + qk = logit_cap * tanh(qk / logit_cap) + qk = tl.where(mask_m[:, None] & mask_n[None, :], qk, float("-inf")) n_e_max = tl.maximum(tl.max(qk, 1), e_max) @@ -126,6 +137,10 @@ def _fwd_kernel( qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) qk += tl.dot(q, k) qk *= sm_scale + + if logit_cap > 0: + qk = logit_cap * tanh(qk / logit_cap) + mask_causual = (cur_block_m * BLOCK_M + offs_m[:, None]) >= ( start_n + offs_n[None, :] ) @@ -176,6 +191,7 @@ def extend_attention_fwd( b_seq_len_extend, max_len_in_batch, max_len_extend, + logit_cap=-1, ): """ q_extend, k_extend, v_extend, o_extend: contiguous tensors @@ -271,6 +287,7 @@ def extend_attention_fwd( BLOCK_N=BLOCK_N, num_warps=num_warps, num_stages=num_stages, + logit_cap=logit_cap, ) cached_kernel = wrap_kernel_launcher(_fwd_kernel) diff --git a/python/sglang/srt/layers/radix_attention.py b/python/sglang/srt/layers/radix_attention.py index cef2c3b7f..59d1a54da 100644 --- a/python/sglang/srt/layers/radix_attention.py +++ b/python/sglang/srt/layers/radix_attention.py @@ -1,4 +1,5 @@ import torch +import numpy as np from torch import nn from sglang.srt.layers.context_flashattention_nopad import context_attention_fwd @@ -8,13 +9,16 @@ from sglang.srt.managers.router.model_runner import ForwardMode, InputMetadata class RadixAttention(nn.Module): - def __init__(self, num_heads, head_dim, scaling, num_kv_heads, layer_id): + def __init__(self, num_heads, head_dim, scaling, num_kv_heads, layer_id, logit_cap=-1): super().__init__() self.tp_q_head_num = num_heads self.tp_k_head_num = num_kv_heads self.tp_v_head_num = num_kv_heads self.head_dim = head_dim self.layer_id = layer_id + self.logit_cap = logit_cap + + assert np.allclose(scaling, 1.0 / (head_dim**0.5)) from sglang.srt.managers.router.model_runner import global_server_args_dict @@ -38,6 +42,7 @@ class RadixAttention(nn.Module): input_metadata.start_loc, input_metadata.seq_lens, input_metadata.max_seq_len, + self.logit_cap, ) self.store_kv_cache(k, v, input_metadata) @@ -62,6 +67,7 @@ class RadixAttention(nn.Module): input_metadata.extend_seq_lens, input_metadata.max_seq_len, input_metadata.max_extend_len, + self.logit_cap, ) return o @@ -82,6 +88,7 @@ class RadixAttention(nn.Module): input_metadata.max_seq_len, input_metadata.other_kv_index, input_metadata.total_num_tokens, + self.logit_cap, ) return o diff --git a/python/sglang/srt/layers/token_attention.py b/python/sglang/srt/layers/token_attention.py index 7b03e3ffe..58e3fa611 100644 --- a/python/sglang/srt/layers/token_attention.py +++ b/python/sglang/srt/layers/token_attention.py @@ -16,6 +16,12 @@ else: REDUCE_TORCH_TYPE = torch.float16 +@triton.jit +def tanh(x): + # Tanh is just a scaled sigmoid + return 2 * tl.sigmoid(2 * x) - 1 + + @triton.jit def _fwd_kernel_stage1( Q, @@ -35,6 +41,7 @@ def _fwd_kernel_stage1( kv_group_num: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, + logit_cap: tl.constexpr, ): cur_batch = tl.program_id(0) cur_head = tl.program_id(1) @@ -77,6 +84,10 @@ def 
_fwd_kernel_stage1( ).to(REDUCE_TRITON_TYPE) att_value = tl.sum(q[None, :] * k, 1) att_value *= sm_scale + + if logit_cap > 0: + att_value = logit_cap * tanh(att_value / logit_cap) + off_o = cur_head * att_stride_h + (cur_batch_in_all_start_index + offs_n) tl.store(Att_Out + off_o, att_value, mask=offs_n_new < cur_batch_end_index) @@ -165,6 +176,7 @@ def _token_att_m_fwd( B_Start_Loc, B_Seqlen, max_len_in_batch, + logit_cap, ): BLOCK = 32 # shape constraints @@ -223,6 +235,7 @@ def _token_att_m_fwd( kv_group_num=kv_group_num, BLOCK_DMODEL=Lk, BLOCK_N=BLOCK, + logit_cap=logit_cap, num_warps=num_warps, num_stages=1, ) @@ -304,6 +317,7 @@ def token_attention_fwd( max_len_in_batch, other_kv_index, total_num_tokens, + logit_cap=-1, att_m=None, ): if att_m is None: @@ -320,6 +334,7 @@ def token_attention_fwd( b_start_loc, b_seq_len, max_len_in_batch, + logit_cap, ) _token_softmax_reducev_fwd( att_m, diff --git a/python/sglang/srt/managers/detokenizer_manager.py b/python/sglang/srt/managers/detokenizer_manager.py index 52bad9792..eeefbe0ba 100644 --- a/python/sglang/srt/managers/detokenizer_manager.py +++ b/python/sglang/srt/managers/detokenizer_manager.py @@ -1,4 +1,5 @@ import asyncio +import inspect import uvloop import zmq @@ -7,7 +8,7 @@ import zmq.asyncio from sglang.srt.hf_transformers_utils import get_tokenizer from sglang.srt.managers.io_struct import BatchStrOut, BatchTokenIDOut from sglang.srt.server_args import PortArgs, ServerArgs -from sglang.utils import get_exception_traceback +from sglang.utils import get_exception_traceback, graceful_registry asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) @@ -85,6 +86,8 @@ def start_detokenizer_process( port_args: PortArgs, pipe_writer, ): + graceful_registry(inspect.currentframe().f_code.co_name) + try: manager = DetokenizerManager(server_args, port_args) except Exception as e: diff --git a/python/sglang/srt/managers/router/model_rpc.py b/python/sglang/srt/managers/router/model_rpc.py index 2873ef4c5..6abb20b25 100644 --- a/python/sglang/srt/managers/router/model_rpc.py +++ b/python/sglang/srt/managers/router/model_rpc.py @@ -106,8 +106,7 @@ class ModelRpcServer: set_random_seed(server_args.random_seed) # Print info - logger.info( - f"Rank {self.tp_rank}: " + logger.info(f"[rank={self.tp_rank}] " f"max_total_num_token={self.max_total_num_token}, " f"max_prefill_num_token={self.max_prefill_num_token}, " f"context_len={self.model_config.context_len}, " @@ -752,7 +751,7 @@ def _init_service(port): protocol_config={ "allow_public_attrs": True, "allow_pickle": True, - "sync_request_timeout": 1800, + "sync_request_timeout": 3600, }, ) t.start() @@ -772,7 +771,7 @@ def start_model_process(port): config={ "allow_public_attrs": True, "allow_pickle": True, - "sync_request_timeout": 1800, + "sync_request_timeout": 3600, }, ) break diff --git a/python/sglang/srt/managers/router/model_runner.py b/python/sglang/srt/managers/router/model_runner.py index cea08f5da..34b160789 100644 --- a/python/sglang/srt/managers/router/model_runner.py +++ b/python/sglang/srt/managers/router/model_runner.py @@ -235,8 +235,8 @@ class ModelRunner: } # Init torch distributed - logger.debug("Init torch begin.") torch.cuda.set_device(self.tp_rank) + logger.info(f"[rank={self.tp_rank}] Init torch begin. 
Avail mem={get_available_gpu_memory(self.tp_rank):.2f} GB") torch.distributed.init_process_group( backend="nccl", world_size=self.tp_size, @@ -244,20 +244,22 @@ class ModelRunner: init_method=f"tcp://127.0.0.1:{self.nccl_port}", ) initialize_model_parallel(tensor_model_parallel_size=self.tp_size) - logger.debug("Init torch end.") + logger.info(f"[rank={self.tp_rank}] Init torch end.") + + total_gpu_memory = get_available_gpu_memory(self.tp_rank, distributed=self.tp_size > 1) + + if self.tp_size > 1: + total_local_gpu_memory = get_available_gpu_memory(self.tp_rank) + if total_local_gpu_memory < total_gpu_memory * 0.9: + raise ValueError("The memory capacity is unbalanced. Some GPUs may be occupied by other processes.") - total_gpu_memory = get_available_gpu_memory( - self.tp_rank, distributed=self.tp_size > 1 - ) * (1 << 30) - # logger.info(f"Before: {get_available_gpu_memory(self.tp_rank, False):.2f} GB") self.load_model() - # logger.info(f"After: {get_available_gpu_memory(self.tp_rank, False):.2f} GB") self.init_memory_pool(total_gpu_memory) self.is_multimodal_model = is_multimodal_model(self.model_config) def load_model(self): - logger.info(f"Rank {self.tp_rank}: load weight begin.") + logger.info(f"[rank={self.tp_rank}] Load weight begin.") device_config = DeviceConfig() load_config = LoadConfig(load_format=self.server_args.load_format) @@ -283,19 +285,19 @@ class ModelRunner: parallel_config=None, scheduler_config=None, ) - logger.info(f"Rank {self.tp_rank}: load weight end. {type(self.model)}") + logger.info(f"[rank={self.tp_rank}] Load weight end. " + f"Type={type(self.model).__name__}. " + f"Avail mem={get_available_gpu_memory(self.tp_rank):.2f} GB") def profile_max_num_token(self, total_gpu_memory): - available_gpu_memory = get_available_gpu_memory( - self.tp_rank, distributed=self.tp_size > 1 - ) * (1 << 30) + available_gpu_memory = get_available_gpu_memory(self.tp_rank, distributed=self.tp_size > 1) head_dim = self.model_config.head_dim head_num = self.model_config.num_key_value_heads // self.tp_size cell_size = head_num * head_dim * self.model_config.num_hidden_layers * 2 * 2 rest_memory = available_gpu_memory - total_gpu_memory * ( 1 - self.mem_fraction_static ) - max_num_token = int(rest_memory // cell_size) + max_num_token = int(rest_memory * (1 << 30) // cell_size) return max_num_token def init_memory_pool(self, total_gpu_memory): diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py index a94359707..695e129ed 100644 --- a/python/sglang/srt/server.py +++ b/python/sglang/srt/server.py @@ -203,7 +203,6 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg time.sleep(0.5) try: requests.get(url + "/get_model_info", timeout=5, headers=headers) - success = True # Set flag to True if request succeeds break except requests.exceptions.RequestException as e: pass @@ -213,7 +212,7 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg res = requests.post( url + "/generate", json={ - "text": "Say this is a warmup request.", + "text": "The capital city of France is", "sampling_params": { "temperature": 0, "max_new_tokens": 16, diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py index 981e82152..f006ad6b1 100644 --- a/python/sglang/srt/utils.py +++ b/python/sglang/srt/utils.py @@ -92,7 +92,7 @@ def calculate_time(show=False, min_cost_ms=0.0): return wrapper -def get_available_gpu_memory(gpu_id, distributed=True): +def get_available_gpu_memory(gpu_id, distributed=False): """ Get available memory 
for cuda:gpu_id device.
    When distributed is True, the available memory is the minimum available memory of all GPUs.
diff --git a/python/sglang/utils.py b/python/sglang/utils.py
index 332551aa0..e4a3e9adb 100644
--- a/python/sglang/utils.py
+++ b/python/sglang/utils.py
@@ -2,7 +2,8 @@
 import base64
 import json
-import os
+import logging
+import signal
 import sys
 import threading
 import traceback
@@ -15,6 +16,9 @@ import numpy as np
 import requests
 
 
+logger = logging.getLogger(__name__)
+
+
 def get_exception_traceback():
     etype, value, tb = sys.exc_info()
     err_str = "".join(traceback.format_exception(etype, value, tb))
@@ -247,3 +251,12 @@ def run_with_timeout(func, args=(), kwargs=None, timeout=None):
         raise RuntimeError()
 
     return ret_value[0]
+
+
+def graceful_registry(sub_module_name):
+    def graceful_shutdown(signum, frame):
+        logger.info(f"{sub_module_name} received a signal to shut down. Performing graceful shutdown...")
+        if signum == signal.SIGTERM:
+            logger.info(f"{sub_module_name} received SIGTERM")
+
+    signal.signal(signal.SIGTERM, graceful_shutdown)
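
Note on the logit cap added in the attention kernels above: it implements tanh
soft-capping of attention logits. Scores are squashed smoothly into the range
(-logit_cap, logit_cap), and a non-positive cap (the default of -1) leaves the
logits untouched. The Triton kernels compute tanh through the scaled-sigmoid
identity tanh(x) = 2*sigmoid(2x) - 1 via tl.sigmoid. Below is a minimal PyTorch
sketch of the same transform for reference; the function name soft_cap is
illustrative only and is not part of this patch:

    import torch

    def soft_cap(qk: torch.Tensor, logit_cap: float) -> torch.Tensor:
        # Mirror the kernel logic: only cap when logit_cap > 0.
        if logit_cap > 0:
            return logit_cap * torch.tanh(qk / logit_cap)
        return qk

    # Sanity check of the scaled-sigmoid identity used in the Triton kernels.
    x = torch.randn(8)
    assert torch.allclose(2 * torch.sigmoid(2 * x) - 1, torch.tanh(x), atol=1e-6)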