From e1792cca2491af86f29782a3b83533a6566ac75b Mon Sep 17 00:00:00 2001
From: Lianmin Zheng
Date: Thu, 18 Jul 2024 23:28:40 -0700
Subject: [PATCH] Remove cached triton launcher (#656)

---
 .../layers/context_flashattention_nopad.py    | 29 -------
 python/sglang/srt/layers/extend_attention.py  | 39 ---------
 python/sglang/srt/layers/token_attention.py   | 50 ------------
 python/sglang/srt/server.py                   | 27 +------
 python/sglang/srt/utils.py                    | 80 +++----------------
 5 files changed, 15 insertions(+), 210 deletions(-)

diff --git a/python/sglang/srt/layers/context_flashattention_nopad.py b/python/sglang/srt/layers/context_flashattention_nopad.py
index 0c3102c3f..ee1227dbd 100644
--- a/python/sglang/srt/layers/context_flashattention_nopad.py
+++ b/python/sglang/srt/layers/context_flashattention_nopad.py
@@ -4,8 +4,6 @@ import torch
 import triton
 import triton.language as tl
 
-from sglang.srt.utils import wrap_kernel_launcher
-
 CUDA_CAPABILITY = torch.cuda.get_device_capability()
 
 
@@ -119,9 +117,6 @@
     tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len)
 
 
-cached_kernel = None
-
-
 def context_attention_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len):
     if CUDA_CAPABILITY[0] >= 8:
         BLOCK = 128
@@ -139,29 +134,6 @@ def context_attention_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len):
     grid = (batch, head, triton.cdiv(max_input_len, BLOCK))
     num_warps = 4 if Lk <= 64 else 8
 
-    global cached_kernel
-    if cached_kernel:
-        cached_kernel(
-            grid,
-            num_warps,
-            q,
-            k,
-            v,
-            sm_scale,
-            b_start_loc,
-            b_seq_len,
-            o,
-            q.stride(0),
-            q.stride(1),
-            k.stride(0),
-            k.stride(1),
-            v.stride(0),
-            v.stride(1),
-            o.stride(0),
-            o.stride(1),
-        )
-        return
-
     _fwd_kernel[grid](
         q,
         k,
@@ -185,4 +157,3 @@ def context_attention_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len):
         num_warps=num_warps,
         num_stages=1,
     )
-    cached_kernel = wrap_kernel_launcher(_fwd_kernel)
diff --git a/python/sglang/srt/layers/extend_attention.py b/python/sglang/srt/layers/extend_attention.py
index 41c2ca7d1..833c6b118 100644
--- a/python/sglang/srt/layers/extend_attention.py
+++ b/python/sglang/srt/layers/extend_attention.py
@@ -3,7 +3,6 @@ import triton
 import triton.language as tl
 
 from sglang.srt.layers.context_flashattention_nopad import context_attention_fwd
-from sglang.srt.utils import wrap_kernel_launcher
 
 CUDA_CAPABILITY = torch.cuda.get_device_capability()
 
@@ -172,9 +171,6 @@
     tl.store(O_Extend + offs_o, acc / deno[:, None], mask=mask_m[:, None])
 
 
-cached_kernel = None
-
-
 def extend_attention_fwd(
     q_extend,
     k_extend,
@@ -222,40 +218,6 @@ def extend_attention_fwd(
     num_warps = 4 if Lk <= 64 else 8
     num_stages = 1
 
-    global cached_kernel
-    if cached_kernel:
-        cached_kernel(
-            grid,
-            num_warps,
-            q_extend,
-            k_extend,
-            v_extend,
-            o_extend,
-            k_buffer,
-            v_buffer,
-            req_to_tokens,
-            b_req_idx,
-            b_seq_len,
-            b_start_loc_extend,
-            b_seq_len_extend,
-            sm_scale,
-            kv_group_num,
-            q_extend.stride(0),
-            q_extend.stride(1),
-            k_extend.stride(0),
-            k_extend.stride(1),
-            v_extend.stride(0),
-            v_extend.stride(1),
-            o_extend.stride(0),
-            o_extend.stride(1),
-            k_buffer.stride(0),
-            k_buffer.stride(1),
-            v_buffer.stride(0),
-            v_buffer.stride(1),
-            req_to_tokens.stride(0),
-        )
-        return
-
     _fwd_kernel[grid](
         q_extend,
         k_extend,
@@ -290,7 +252,6 @@ def extend_attention_fwd(
         num_stages=num_stages,
         logit_cap=logit_cap,
     )
-    cached_kernel = wrap_kernel_launcher(_fwd_kernel)
 
 
 def redundant_attention(
diff --git a/python/sglang/srt/layers/token_attention.py b/python/sglang/srt/layers/token_attention.py
index d8db1e01d..688e1ddd2 100644
--- a/python/sglang/srt/layers/token_attention.py
+++ b/python/sglang/srt/layers/token_attention.py
@@ -6,7 +6,6 @@ import triton
 import triton.language as tl
 
 from sglang.srt.server import global_server_args_dict
-from sglang.srt.utils import wrap_kernel_launcher
 
 if global_server_args_dict.get("attention_reduce_in_fp32", False):
     REDUCE_TRITON_TYPE = tl.float32
@@ -162,10 +161,6 @@
     tl.store(out_ptrs, acc)
 
 
-cached_kernel_stage1 = None
-cached_kernel_stage2 = None
-
-
 def _token_att_m_fwd(
     q,
     k_buffer,
@@ -194,28 +189,6 @@ def _token_att_m_fwd(
     else:
         num_warps = 2
 
-    global cached_kernel_stage1
-    if cached_kernel_stage1:
-        cached_kernel_stage1(
-            grid,
-            num_warps,
-            q,
-            k_buffer,
-            sm_scale,
-            Req_to_tokens,
-            B_req_idx,
-            B_Start_Loc,
-            B_Seqlen,
-            att_out,
-            Req_to_tokens.stride(0),
-            q.stride(0),
-            q.stride(1),
-            k_buffer.stride(0),
-            k_buffer.stride(1),
-            att_out.stride(0),
-        )
-        return
-
     _fwd_kernel_stage1[grid](
         q,
         k_buffer,
@@ -238,7 +211,6 @@ def _token_att_m_fwd(
         num_warps=num_warps,
         num_stages=1,
     )
-    cached_kernel_stage1 = wrap_kernel_launcher(_fwd_kernel_stage1)
 
 
 def _token_softmax_reducev_fwd(
@@ -257,27 +229,6 @@ def _token_softmax_reducev_fwd(
 
     num_warps = 1
 
-    global cached_kernel_stage2
-    if cached_kernel_stage2:
-        cached_kernel_stage2(
-            grid,
-            num_warps,
-            logics,
-            v_buffer,
-            o,
-            req_to_tokens,
-            b_req_idx,
-            b_start_loc,
-            b_seq_len,
-            logics.stride(0),
-            v_buffer.stride(0),
-            v_buffer.stride(1),
-            o.stride(0),
-            o.stride(1),
-            req_to_tokens.stride(0),
-        )
-        return
-
     _fwd_kernel_stage2[grid](
         logics,
         v_buffer,
@@ -298,7 +249,6 @@ def _token_softmax_reducev_fwd(
        num_warps=num_warps,
        num_stages=3,
    )
-    cached_kernel_stage2 = wrap_kernel_launcher(_fwd_kernel_stage2)
 
 
 def token_attention_fwd(
diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py
index a0b6b0f69..749cb774d 100644
--- a/python/sglang/srt/server.py
+++ b/python/sglang/srt/server.py
@@ -51,6 +51,7 @@ from sglang.srt.utils import (
     allocate_init_ports,
     assert_pkg_version,
     enable_show_time_cost,
+    set_ulimit,
 )
 from sglang.utils import get_exception_traceback
 
@@ -145,30 +146,6 @@ def _set_global_server_args(server_args: ServerArgs):
     }
 
 
-def _set_ulimit(target_soft_limit=65535):
-    import resource
-
-    resource_type = resource.RLIMIT_NOFILE
-    current_soft, current_hard = resource.getrlimit(resource_type)
-
-    if current_soft >= target_soft_limit:
-        logger.info(
-            f"Current limits are already sufficient: soft={current_soft}, hard={current_hard}"
-        )
-    else:
-        try:
-            resource.setrlimit(resource_type, (target_soft_limit, current_hard))
-            new_soft, new_hard = resource.getrlimit(resource_type)
-            logger.info(
-                f"Successfully set new limits: soft={new_soft}, hard={new_hard}"
-            )
-        except ValueError as e:
-            logger.warn(f"Failed to set new limits: {e}")
-            logger.info(
-                f"Limits remain unchanged: soft={current_soft}, hard={current_hard}"
-            )
-
-
 def launch_server(
     server_args: ServerArgs,
     model_overide_args: Optional[dict] = None,
@@ -186,7 +163,7 @@
     os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
     os.environ["NCCL_CUMEM_ENABLE"] = "0"
     os.environ["NCCL_NVLS_ENABLE"] = "0"
-    _set_ulimit()
+    set_ulimit()
     if server_args.show_time_cost:
         enable_show_time_cost()
     if server_args.disable_disk_cache:
diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py
index 66f051ea7..a88aa894d 100644
--- a/python/sglang/srt/utils.py
+++ b/python/sglang/srt/utils.py
@@ -5,6 +5,7 @@ import fcntl
 import logging
 import os
 import random
+import resource
 import socket
 import struct
 import time
@@ -16,6 +17,7 @@ import numpy as np
 import psutil
 import requests
 import torch
+import torch.distributed as dist
 import triton
 from fastapi.responses import JSONResponse
 from packaging import version as pkg_version
@@ -184,71 +186,6 @@ def get_int_token_logit_bias(tokenizer, vocab_size):
     return logit_bias
 
 
-def wrap_kernel_launcher(kernel):
-    """A faster launcher for triton kernels."""
-    if int(triton.__version__.split(".")[0]) >= 3:
-        return None
-
-    gpu_id = torch.cuda.current_device()
-    kernels = kernel.cache[gpu_id].values()
-    kernel = next(iter(kernels))
-
-    # Different trition versions use different low-level names
-    if hasattr(kernel, "cu_function"):
-        kfunction = kernel.cu_function
-    else:
-        kfunction = kernel.function
-
-    if hasattr(kernel, "c_wrapper"):
-        run = kernel.c_wrapper
-    else:
-        run = kernel.run
-
-    add_cluster_dim = True
-
-    def ret_func(grid, num_warps, *args):
-        nonlocal add_cluster_dim
-
-        try:
-            if add_cluster_dim:
-                run(
-                    grid[0],
-                    grid[1],
-                    grid[2],
-                    num_warps,
-                    1,
-                    1,
-                    1,
-                    1,
-                    kernel.shared,
-                    0,
-                    kfunction,
-                    None,
-                    None,
-                    kernel,
-                    *args,
-                )
-            else:
-                run(
-                    grid[0],
-                    grid[1],
-                    grid[2],
-                    num_warps,
-                    kernel.shared,
-                    0,
-                    kfunction,
-                    None,
-                    None,
-                    kernel,
-                    *args,
-                )
-        except TypeError:
-            add_cluster_dim = not add_cluster_dim
-            ret_func(grid, num_warps, *args)
-
-    return ret_func
-
-
 def is_multimodal_model(model):
     from sglang.srt.model_config import ModelConfig
 
@@ -512,7 +449,6 @@ def get_ip_address(ifname):
 def send_addrs_to_rank_0(model_port_args, server_args):
     assert server_args.node_rank != 0 and server_args.dp_size == 1
-    import torch.distributed as dist
 
     ifname = os.environ.get(
         "SGLANG_SOCKET_IFNAME", os.environ.get("NCCL_SOCKET_IFNAME", "eth0")
     )
@@ -544,7 +480,6 @@
 def receive_addrs(model_port_args, server_args):
     assert server_args.node_rank == 0 and server_args.dp_size == 1
-    import torch.distributed as dist
 
     ifname = os.environ.get(
         "SGLANG_SOCKET_IFNAME", os.environ.get("NCCL_SOCKET_IFNAME", "eth0")
     )
@@ -577,3 +512,14 @@
 
     dist.barrier()
     dist.destroy_process_group()
+
+
+def set_ulimit(target_soft_limit=65535):
+    resource_type = resource.RLIMIT_NOFILE
+    current_soft, current_hard = resource.getrlimit(resource_type)
+
+    if current_soft < target_soft_limit:
+        try:
+            resource.setrlimit(resource_type, (target_soft_limit, current_hard))
+        except ValueError as e:
+            logger.warn(f"Fail to set RLIMIT_NOFILE: {e}")
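
Note (not part of the patch): after this change every call site uses Triton's normal kernel[grid](...) dispatch. The removed wrap_kernel_launcher helper pre-resolved a compiled kernel's low-level entry point (cu_function / c_wrapper) to skip that Python dispatch, but its own guard already returned None on Triton >= 3, so the cached fast path was dead code on recent Triton. Below is a minimal sketch of the single launch path the patch standardizes on; the kernel _copy_kernel, the launch helper, and all parameter values are hypothetical stand-ins, not code from this repository.

# Minimal sketch, assuming a toy Triton kernel (_copy_kernel) and helper
# (launch); both are illustrative stand-ins, not part of this patch.
import torch
import triton
import triton.language as tl


@triton.jit
def _copy_kernel(x_ptr, y_ptr, n_elements, BLOCK: tl.constexpr):
    # Trivial elementwise copy, used only to illustrate the launch path.
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n_elements
    tl.store(y_ptr + offs, tl.load(x_ptr + offs, mask=mask), mask=mask)


def launch(x, y):
    n_elements = x.numel()
    grid = (triton.cdiv(n_elements, 128),)
    # Triton's own dispatch compiles once per specialization and caches the
    # binary internally, which is why a hand-rolled launcher cache adds
    # complexity without a payoff on recent Triton versions.
    _copy_kernel[grid](x, y, n_elements, BLOCK=128, num_warps=4, num_stages=1)


if __name__ == "__main__":
    x = torch.randn(1 << 20, device="cuda")
    y = torch.empty_like(x)
    launch(x, y)
    assert torch.equal(x, y)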