Improve logging & add logit cap (#471)

This commit is contained in:
Lianmin Zheng
2024-05-24 03:48:53 -07:00
committed by GitHub
parent 44c998fcb5
commit 2cea6146d8
12 changed files with 106 additions and 24 deletions

View File

@@ -30,7 +30,7 @@ if __name__ == "__main__":
response = requests.post( response = requests.post(
url + "/generate", url + "/generate",
json={ json={
"text": f"{a}, ", "text": f"The capital of France is",
# "input_ids": [[2] * 256] * 196, # "input_ids": [[2] * 256] * 196,
"sampling_params": { "sampling_params": {
"temperature": 0, "temperature": 0,

View File

@@ -6,6 +6,9 @@ class FSMCache(BaseCache):
def __init__(self, tokenizer_path, tokenizer_args_dict, enable=True): def __init__(self, tokenizer_path, tokenizer_args_dict, enable=True):
super().__init__(enable=enable) super().__init__(enable=enable)
if tokenizer_path.endswith(".json"):
return
from importlib.metadata import version from importlib.metadata import version
if version("outlines") >= "0.0.35": if version("outlines") >= "0.0.35":

View File

@@ -84,6 +84,9 @@ def get_tokenizer(
tokenizer_revision: Optional[str] = None, tokenizer_revision: Optional[str] = None,
**kwargs, **kwargs,
) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]: ) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
if tokenizer_name.endswith(".json"):
return TiktokenTokenizer(tokenizer_name)
"""Gets a tokenizer for the given model name via Huggingface.""" """Gets a tokenizer for the given model name via Huggingface."""
if is_multimodal_model(tokenizer_name): if is_multimodal_model(tokenizer_name):
processor = get_processor( processor = get_processor(
@@ -170,3 +173,24 @@ def get_processor(
**kwargs, **kwargs,
) )
return processor return processor
class TiktokenTokenizer:
    """Thin adapter exposing a tiktoken encoding through a HuggingFace-like
    tokenizer interface (encode/decode/batch_decode/convert_ids_to_tokens)."""

    def __init__(self, tokenizer_path):
        # Imported lazily: the wrapper is only needed when a tiktoken
        # JSON tokenizer file is supplied instead of an HF tokenizer.
        import xlm.tokenizers.tiktoken_wrapper as tiktoken_wrapper

        self.tokenizer = tiktoken_wrapper.Encoding.from_xtok_json(
            "xtok-json", tokenizer_path
        )
        self.eos_token_id = self.tokenizer.eos_token
        self.vocab_size = self.tokenizer.n_vocab

    def encode(self, x):
        """Encode a string into a list of token ids."""
        return self.tokenizer.encode(x)

    def decode(self, x):
        """Decode a list of token ids back into a string."""
        return self.tokenizer.decode(x)

    def batch_decode(self, batch, skip_special_tokens, spaces_between_special_tokens):
        # NOTE(review): the two flag arguments exist only for HF interface
        # compatibility and are ignored by the tiktoken backend.
        return self.tokenizer.decode_batch(batch)

    def convert_ids_to_tokens(self, index):
        # Undecodable byte sequences are silently dropped (errors="ignore")
        # rather than raising on partial UTF-8.
        return self.tokenizer.decode_single_token_bytes(index).decode(
            "utf-8", errors="ignore"
        )

View File

@@ -8,6 +8,12 @@ from sglang.srt.utils import wrap_kernel_launcher
CUDA_CAPABILITY = torch.cuda.get_device_capability() CUDA_CAPABILITY = torch.cuda.get_device_capability()
@triton.jit
def tanh(x):
    # Triton has no native tanh primitive; use the identity
    # tanh(x) = 2 * sigmoid(2x) - 1, which tl.sigmoid supports.
    # Used below to soft-cap attention logits when logit_cap > 0.
    return 2 * tl.sigmoid(2 * x) - 1
@triton.jit @triton.jit
def _fwd_kernel( def _fwd_kernel(
Q_Extend, Q_Extend,
@@ -39,6 +45,7 @@ def _fwd_kernel(
BLOCK_DMODEL: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
BLOCK_M: tl.constexpr, BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr, BLOCK_N: tl.constexpr,
logit_cap: tl.constexpr,
): ):
cur_seq = tl.program_id(0) cur_seq = tl.program_id(0)
cur_head = tl.program_id(1) cur_head = tl.program_id(1)
@@ -90,6 +97,10 @@ def _fwd_kernel(
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
qk += tl.dot(q, k) qk += tl.dot(q, k)
qk *= sm_scale qk *= sm_scale
if logit_cap > 0:
qk = logit_cap * tanh(qk / logit_cap)
qk = tl.where(mask_m[:, None] & mask_n[None, :], qk, float("-inf")) qk = tl.where(mask_m[:, None] & mask_n[None, :], qk, float("-inf"))
n_e_max = tl.maximum(tl.max(qk, 1), e_max) n_e_max = tl.maximum(tl.max(qk, 1), e_max)
@@ -126,6 +137,10 @@ def _fwd_kernel(
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
qk += tl.dot(q, k) qk += tl.dot(q, k)
qk *= sm_scale qk *= sm_scale
if logit_cap > 0:
qk = logit_cap * tanh(qk / logit_cap)
mask_causual = (cur_block_m * BLOCK_M + offs_m[:, None]) >= ( mask_causual = (cur_block_m * BLOCK_M + offs_m[:, None]) >= (
start_n + offs_n[None, :] start_n + offs_n[None, :]
) )
@@ -176,6 +191,7 @@ def extend_attention_fwd(
b_seq_len_extend, b_seq_len_extend,
max_len_in_batch, max_len_in_batch,
max_len_extend, max_len_extend,
logit_cap=-1,
): ):
""" """
q_extend, k_extend, v_extend, o_extend: contiguous tensors q_extend, k_extend, v_extend, o_extend: contiguous tensors
@@ -271,6 +287,7 @@ def extend_attention_fwd(
BLOCK_N=BLOCK_N, BLOCK_N=BLOCK_N,
num_warps=num_warps, num_warps=num_warps,
num_stages=num_stages, num_stages=num_stages,
logit_cap=logit_cap,
) )
cached_kernel = wrap_kernel_launcher(_fwd_kernel) cached_kernel = wrap_kernel_launcher(_fwd_kernel)

View File

@@ -1,4 +1,5 @@
import torch import torch
import numpy as np
from torch import nn from torch import nn
from sglang.srt.layers.context_flashattention_nopad import context_attention_fwd from sglang.srt.layers.context_flashattention_nopad import context_attention_fwd
@@ -8,13 +9,16 @@ from sglang.srt.managers.router.model_runner import ForwardMode, InputMetadata
class RadixAttention(nn.Module): class RadixAttention(nn.Module):
def __init__(self, num_heads, head_dim, scaling, num_kv_heads, layer_id): def __init__(self, num_heads, head_dim, scaling, num_kv_heads, layer_id, logit_cap=-1):
super().__init__() super().__init__()
self.tp_q_head_num = num_heads self.tp_q_head_num = num_heads
self.tp_k_head_num = num_kv_heads self.tp_k_head_num = num_kv_heads
self.tp_v_head_num = num_kv_heads self.tp_v_head_num = num_kv_heads
self.head_dim = head_dim self.head_dim = head_dim
self.layer_id = layer_id self.layer_id = layer_id
self.logit_cap = logit_cap
assert np.allclose(scaling, 1.0 / (head_dim**0.5))
from sglang.srt.managers.router.model_runner import global_server_args_dict from sglang.srt.managers.router.model_runner import global_server_args_dict
@@ -38,6 +42,7 @@ class RadixAttention(nn.Module):
input_metadata.start_loc, input_metadata.start_loc,
input_metadata.seq_lens, input_metadata.seq_lens,
input_metadata.max_seq_len, input_metadata.max_seq_len,
self.logit_cap,
) )
self.store_kv_cache(k, v, input_metadata) self.store_kv_cache(k, v, input_metadata)
@@ -62,6 +67,7 @@ class RadixAttention(nn.Module):
input_metadata.extend_seq_lens, input_metadata.extend_seq_lens,
input_metadata.max_seq_len, input_metadata.max_seq_len,
input_metadata.max_extend_len, input_metadata.max_extend_len,
self.logit_cap,
) )
return o return o
@@ -82,6 +88,7 @@ class RadixAttention(nn.Module):
input_metadata.max_seq_len, input_metadata.max_seq_len,
input_metadata.other_kv_index, input_metadata.other_kv_index,
input_metadata.total_num_tokens, input_metadata.total_num_tokens,
self.logit_cap,
) )
return o return o

View File

@@ -16,6 +16,12 @@ else:
REDUCE_TORCH_TYPE = torch.float16 REDUCE_TORCH_TYPE = torch.float16
@triton.jit
def tanh(x):
    # Triton has no native tanh primitive; use the identity
    # tanh(x) = 2 * sigmoid(2x) - 1, which tl.sigmoid supports.
    # Used below to soft-cap attention logits when logit_cap > 0.
    return 2 * tl.sigmoid(2 * x) - 1
@triton.jit @triton.jit
def _fwd_kernel_stage1( def _fwd_kernel_stage1(
Q, Q,
@@ -35,6 +41,7 @@ def _fwd_kernel_stage1(
kv_group_num: tl.constexpr, kv_group_num: tl.constexpr,
BLOCK_DMODEL: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr, BLOCK_N: tl.constexpr,
logit_cap: tl.constexpr,
): ):
cur_batch = tl.program_id(0) cur_batch = tl.program_id(0)
cur_head = tl.program_id(1) cur_head = tl.program_id(1)
@@ -77,6 +84,10 @@ def _fwd_kernel_stage1(
).to(REDUCE_TRITON_TYPE) ).to(REDUCE_TRITON_TYPE)
att_value = tl.sum(q[None, :] * k, 1) att_value = tl.sum(q[None, :] * k, 1)
att_value *= sm_scale att_value *= sm_scale
if logit_cap > 0:
att_value = logit_cap * tanh(att_value / logit_cap)
off_o = cur_head * att_stride_h + (cur_batch_in_all_start_index + offs_n) off_o = cur_head * att_stride_h + (cur_batch_in_all_start_index + offs_n)
tl.store(Att_Out + off_o, att_value, mask=offs_n_new < cur_batch_end_index) tl.store(Att_Out + off_o, att_value, mask=offs_n_new < cur_batch_end_index)
@@ -165,6 +176,7 @@ def _token_att_m_fwd(
B_Start_Loc, B_Start_Loc,
B_Seqlen, B_Seqlen,
max_len_in_batch, max_len_in_batch,
logit_cap,
): ):
BLOCK = 32 BLOCK = 32
# shape constraints # shape constraints
@@ -223,6 +235,7 @@ def _token_att_m_fwd(
kv_group_num=kv_group_num, kv_group_num=kv_group_num,
BLOCK_DMODEL=Lk, BLOCK_DMODEL=Lk,
BLOCK_N=BLOCK, BLOCK_N=BLOCK,
logit_cap=logit_cap,
num_warps=num_warps, num_warps=num_warps,
num_stages=1, num_stages=1,
) )
@@ -304,6 +317,7 @@ def token_attention_fwd(
max_len_in_batch, max_len_in_batch,
other_kv_index, other_kv_index,
total_num_tokens, total_num_tokens,
logit_cap=-1,
att_m=None, att_m=None,
): ):
if att_m is None: if att_m is None:
@@ -320,6 +334,7 @@ def token_attention_fwd(
b_start_loc, b_start_loc,
b_seq_len, b_seq_len,
max_len_in_batch, max_len_in_batch,
logit_cap,
) )
_token_softmax_reducev_fwd( _token_softmax_reducev_fwd(
att_m, att_m,

View File

@@ -1,4 +1,5 @@
import asyncio import asyncio
import inspect
import uvloop import uvloop
import zmq import zmq
@@ -7,7 +8,7 @@ import zmq.asyncio
from sglang.srt.hf_transformers_utils import get_tokenizer from sglang.srt.hf_transformers_utils import get_tokenizer
from sglang.srt.managers.io_struct import BatchStrOut, BatchTokenIDOut from sglang.srt.managers.io_struct import BatchStrOut, BatchTokenIDOut
from sglang.srt.server_args import PortArgs, ServerArgs from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.utils import get_exception_traceback from sglang.utils import get_exception_traceback, graceful_registry
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
@@ -85,6 +86,8 @@ def start_detokenizer_process(
port_args: PortArgs, port_args: PortArgs,
pipe_writer, pipe_writer,
): ):
graceful_registry(inspect.currentframe().f_code.co_name)
try: try:
manager = DetokenizerManager(server_args, port_args) manager = DetokenizerManager(server_args, port_args)
except Exception as e: except Exception as e:

View File

@@ -106,8 +106,7 @@ class ModelRpcServer:
set_random_seed(server_args.random_seed) set_random_seed(server_args.random_seed)
# Print info # Print info
logger.info( logger.info(f"[rank={self.tp_rank}] "
f"Rank {self.tp_rank}: "
f"max_total_num_token={self.max_total_num_token}, " f"max_total_num_token={self.max_total_num_token}, "
f"max_prefill_num_token={self.max_prefill_num_token}, " f"max_prefill_num_token={self.max_prefill_num_token}, "
f"context_len={self.model_config.context_len}, " f"context_len={self.model_config.context_len}, "
@@ -752,7 +751,7 @@ def _init_service(port):
protocol_config={ protocol_config={
"allow_public_attrs": True, "allow_public_attrs": True,
"allow_pickle": True, "allow_pickle": True,
"sync_request_timeout": 1800, "sync_request_timeout": 3600,
}, },
) )
t.start() t.start()
@@ -772,7 +771,7 @@ def start_model_process(port):
config={ config={
"allow_public_attrs": True, "allow_public_attrs": True,
"allow_pickle": True, "allow_pickle": True,
"sync_request_timeout": 1800, "sync_request_timeout": 3600,
}, },
) )
break break

View File

@@ -235,8 +235,8 @@ class ModelRunner:
} }
# Init torch distributed # Init torch distributed
logger.debug("Init torch begin.")
torch.cuda.set_device(self.tp_rank) torch.cuda.set_device(self.tp_rank)
logger.info(f"[rank={self.tp_rank}] Init torch begin. Avail mem={get_available_gpu_memory(self.tp_rank):.2f} GB")
torch.distributed.init_process_group( torch.distributed.init_process_group(
backend="nccl", backend="nccl",
world_size=self.tp_size, world_size=self.tp_size,
@@ -244,20 +244,22 @@ class ModelRunner:
init_method=f"tcp://127.0.0.1:{self.nccl_port}", init_method=f"tcp://127.0.0.1:{self.nccl_port}",
) )
initialize_model_parallel(tensor_model_parallel_size=self.tp_size) initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
logger.debug("Init torch end.") logger.info(f"[rank={self.tp_rank}] Init torch end.")
total_gpu_memory = get_available_gpu_memory(self.tp_rank, distributed=self.tp_size > 1)
if self.tp_size > 1:
total_local_gpu_memory = get_available_gpu_memory(self.tp_rank)
if total_local_gpu_memory < total_gpu_memory * 0.9:
raise ValueError("The memory capacity is unbalanced. Some GPUs may be occupied by other processes.")
total_gpu_memory = get_available_gpu_memory(
self.tp_rank, distributed=self.tp_size > 1
) * (1 << 30)
# logger.info(f"Before: {get_available_gpu_memory(self.tp_rank, False):.2f} GB")
self.load_model() self.load_model()
# logger.info(f"After: {get_available_gpu_memory(self.tp_rank, False):.2f} GB")
self.init_memory_pool(total_gpu_memory) self.init_memory_pool(total_gpu_memory)
self.is_multimodal_model = is_multimodal_model(self.model_config) self.is_multimodal_model = is_multimodal_model(self.model_config)
def load_model(self): def load_model(self):
logger.info(f"Rank {self.tp_rank}: load weight begin.") logger.info(f"[rank={self.tp_rank}] Load weight begin.")
device_config = DeviceConfig() device_config = DeviceConfig()
load_config = LoadConfig(load_format=self.server_args.load_format) load_config = LoadConfig(load_format=self.server_args.load_format)
@@ -283,19 +285,19 @@ class ModelRunner:
parallel_config=None, parallel_config=None,
scheduler_config=None, scheduler_config=None,
) )
logger.info(f"Rank {self.tp_rank}: load weight end. {type(self.model)}") logger.info(f"[rank={self.tp_rank}] Load weight end. "
f"Type={type(self.model).__name__}. "
f"Avail mem={get_available_gpu_memory(self.tp_rank):.2f} GB")
def profile_max_num_token(self, total_gpu_memory): def profile_max_num_token(self, total_gpu_memory):
available_gpu_memory = get_available_gpu_memory( available_gpu_memory = get_available_gpu_memory(self.tp_rank, distributed=self.tp_size > 1)
self.tp_rank, distributed=self.tp_size > 1
) * (1 << 30)
head_dim = self.model_config.head_dim head_dim = self.model_config.head_dim
head_num = self.model_config.num_key_value_heads // self.tp_size head_num = self.model_config.num_key_value_heads // self.tp_size
cell_size = head_num * head_dim * self.model_config.num_hidden_layers * 2 * 2 cell_size = head_num * head_dim * self.model_config.num_hidden_layers * 2 * 2
rest_memory = available_gpu_memory - total_gpu_memory * ( rest_memory = available_gpu_memory - total_gpu_memory * (
1 - self.mem_fraction_static 1 - self.mem_fraction_static
) )
max_num_token = int(rest_memory // cell_size) max_num_token = int(rest_memory * (1 << 30) // cell_size)
return max_num_token return max_num_token
def init_memory_pool(self, total_gpu_memory): def init_memory_pool(self, total_gpu_memory):

View File

@@ -203,7 +203,6 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
time.sleep(0.5) time.sleep(0.5)
try: try:
requests.get(url + "/get_model_info", timeout=5, headers=headers) requests.get(url + "/get_model_info", timeout=5, headers=headers)
success = True # Set flag to True if request succeeds
break break
except requests.exceptions.RequestException as e: except requests.exceptions.RequestException as e:
pass pass
@@ -213,7 +212,7 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
res = requests.post( res = requests.post(
url + "/generate", url + "/generate",
json={ json={
"text": "Say this is a warmup request.", "text": "The capital city of France is",
"sampling_params": { "sampling_params": {
"temperature": 0, "temperature": 0,
"max_new_tokens": 16, "max_new_tokens": 16,

View File

@@ -92,7 +92,7 @@ def calculate_time(show=False, min_cost_ms=0.0):
return wrapper return wrapper
def get_available_gpu_memory(gpu_id, distributed=True): def get_available_gpu_memory(gpu_id, distributed=False):
""" """
Get available memory for cuda:gpu_id device. Get available memory for cuda:gpu_id device.
When distributed is True, the available memory is the minimum available memory of all GPUs. When distributed is True, the available memory is the minimum available memory of all GPUs.

View File

@@ -2,7 +2,8 @@
import base64 import base64
import json import json
import os import logging
import signal
import sys import sys
import threading import threading
import traceback import traceback
@@ -15,6 +16,9 @@ import numpy as np
import requests import requests
logger = logging.getLogger(__name__)
def get_exception_traceback(): def get_exception_traceback():
etype, value, tb = sys.exc_info() etype, value, tb = sys.exc_info()
err_str = "".join(traceback.format_exception(etype, value, tb)) err_str = "".join(traceback.format_exception(etype, value, tb))
@@ -247,3 +251,12 @@ def run_with_timeout(func, args=(), kwargs=None, timeout=None):
raise RuntimeError() raise RuntimeError()
return ret_value[0] return ret_value[0]
def graceful_registry(sub_module_name):
    """Install a SIGTERM handler that logs a graceful-shutdown message.

    Args:
        sub_module_name: Human-readable name of the calling subprocess/module,
            used only to tag the log lines.
    """

    def graceful_shutdown(signum, frame):
        # frame is required by the signal-handler signature but unused.
        logger.info(
            f"{sub_module_name} received a signal to shutdown. "
            "Performing graceful shutdown..."
        )
        if signum == signal.SIGTERM:
            # Fixed typo in the original log message ("recive sigterm").
            logger.info(f"{sub_module_name} received SIGTERM")

    signal.signal(signal.SIGTERM, graceful_shutdown)