Remove cached triton launcher (#656)
This commit is contained in:
@@ -4,8 +4,6 @@ import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
from sglang.srt.utils import wrap_kernel_launcher
|
||||
|
||||
CUDA_CAPABILITY = torch.cuda.get_device_capability()
|
||||
|
||||
|
||||
@@ -119,9 +117,6 @@ def _fwd_kernel(
|
||||
tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len)
|
||||
|
||||
|
||||
cached_kernel = None
|
||||
|
||||
|
||||
def context_attention_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len):
|
||||
if CUDA_CAPABILITY[0] >= 8:
|
||||
BLOCK = 128
|
||||
@@ -139,29 +134,6 @@ def context_attention_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len):
|
||||
grid = (batch, head, triton.cdiv(max_input_len, BLOCK))
|
||||
num_warps = 4 if Lk <= 64 else 8
|
||||
|
||||
global cached_kernel
|
||||
if cached_kernel:
|
||||
cached_kernel(
|
||||
grid,
|
||||
num_warps,
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
sm_scale,
|
||||
b_start_loc,
|
||||
b_seq_len,
|
||||
o,
|
||||
q.stride(0),
|
||||
q.stride(1),
|
||||
k.stride(0),
|
||||
k.stride(1),
|
||||
v.stride(0),
|
||||
v.stride(1),
|
||||
o.stride(0),
|
||||
o.stride(1),
|
||||
)
|
||||
return
|
||||
|
||||
_fwd_kernel[grid](
|
||||
q,
|
||||
k,
|
||||
@@ -185,4 +157,3 @@ def context_attention_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len):
|
||||
num_warps=num_warps,
|
||||
num_stages=1,
|
||||
)
|
||||
cached_kernel = wrap_kernel_launcher(_fwd_kernel)
|
||||
|
||||
@@ -3,7 +3,6 @@ import triton
|
||||
import triton.language as tl
|
||||
|
||||
from sglang.srt.layers.context_flashattention_nopad import context_attention_fwd
|
||||
from sglang.srt.utils import wrap_kernel_launcher
|
||||
|
||||
CUDA_CAPABILITY = torch.cuda.get_device_capability()
|
||||
|
||||
@@ -172,9 +171,6 @@ def _fwd_kernel(
|
||||
tl.store(O_Extend + offs_o, acc / deno[:, None], mask=mask_m[:, None])
|
||||
|
||||
|
||||
cached_kernel = None
|
||||
|
||||
|
||||
def extend_attention_fwd(
|
||||
q_extend,
|
||||
k_extend,
|
||||
@@ -222,40 +218,6 @@ def extend_attention_fwd(
|
||||
num_warps = 4 if Lk <= 64 else 8
|
||||
num_stages = 1
|
||||
|
||||
global cached_kernel
|
||||
if cached_kernel:
|
||||
cached_kernel(
|
||||
grid,
|
||||
num_warps,
|
||||
q_extend,
|
||||
k_extend,
|
||||
v_extend,
|
||||
o_extend,
|
||||
k_buffer,
|
||||
v_buffer,
|
||||
req_to_tokens,
|
||||
b_req_idx,
|
||||
b_seq_len,
|
||||
b_start_loc_extend,
|
||||
b_seq_len_extend,
|
||||
sm_scale,
|
||||
kv_group_num,
|
||||
q_extend.stride(0),
|
||||
q_extend.stride(1),
|
||||
k_extend.stride(0),
|
||||
k_extend.stride(1),
|
||||
v_extend.stride(0),
|
||||
v_extend.stride(1),
|
||||
o_extend.stride(0),
|
||||
o_extend.stride(1),
|
||||
k_buffer.stride(0),
|
||||
k_buffer.stride(1),
|
||||
v_buffer.stride(0),
|
||||
v_buffer.stride(1),
|
||||
req_to_tokens.stride(0),
|
||||
)
|
||||
return
|
||||
|
||||
_fwd_kernel[grid](
|
||||
q_extend,
|
||||
k_extend,
|
||||
@@ -290,7 +252,6 @@ def extend_attention_fwd(
|
||||
num_stages=num_stages,
|
||||
logit_cap=logit_cap,
|
||||
)
|
||||
cached_kernel = wrap_kernel_launcher(_fwd_kernel)
|
||||
|
||||
|
||||
def redundant_attention(
|
||||
|
||||
@@ -6,7 +6,6 @@ import triton
|
||||
import triton.language as tl
|
||||
|
||||
from sglang.srt.server import global_server_args_dict
|
||||
from sglang.srt.utils import wrap_kernel_launcher
|
||||
|
||||
if global_server_args_dict.get("attention_reduce_in_fp32", False):
|
||||
REDUCE_TRITON_TYPE = tl.float32
|
||||
@@ -162,10 +161,6 @@ def _fwd_kernel_stage2(
|
||||
tl.store(out_ptrs, acc)
|
||||
|
||||
|
||||
cached_kernel_stage1 = None
|
||||
cached_kernel_stage2 = None
|
||||
|
||||
|
||||
def _token_att_m_fwd(
|
||||
q,
|
||||
k_buffer,
|
||||
@@ -194,28 +189,6 @@ def _token_att_m_fwd(
|
||||
else:
|
||||
num_warps = 2
|
||||
|
||||
global cached_kernel_stage1
|
||||
if cached_kernel_stage1:
|
||||
cached_kernel_stage1(
|
||||
grid,
|
||||
num_warps,
|
||||
q,
|
||||
k_buffer,
|
||||
sm_scale,
|
||||
Req_to_tokens,
|
||||
B_req_idx,
|
||||
B_Start_Loc,
|
||||
B_Seqlen,
|
||||
att_out,
|
||||
Req_to_tokens.stride(0),
|
||||
q.stride(0),
|
||||
q.stride(1),
|
||||
k_buffer.stride(0),
|
||||
k_buffer.stride(1),
|
||||
att_out.stride(0),
|
||||
)
|
||||
return
|
||||
|
||||
_fwd_kernel_stage1[grid](
|
||||
q,
|
||||
k_buffer,
|
||||
@@ -238,7 +211,6 @@ def _token_att_m_fwd(
|
||||
num_warps=num_warps,
|
||||
num_stages=1,
|
||||
)
|
||||
cached_kernel_stage1 = wrap_kernel_launcher(_fwd_kernel_stage1)
|
||||
|
||||
|
||||
def _token_softmax_reducev_fwd(
|
||||
@@ -257,27 +229,6 @@ def _token_softmax_reducev_fwd(
|
||||
|
||||
num_warps = 1
|
||||
|
||||
global cached_kernel_stage2
|
||||
if cached_kernel_stage2:
|
||||
cached_kernel_stage2(
|
||||
grid,
|
||||
num_warps,
|
||||
logics,
|
||||
v_buffer,
|
||||
o,
|
||||
req_to_tokens,
|
||||
b_req_idx,
|
||||
b_start_loc,
|
||||
b_seq_len,
|
||||
logics.stride(0),
|
||||
v_buffer.stride(0),
|
||||
v_buffer.stride(1),
|
||||
o.stride(0),
|
||||
o.stride(1),
|
||||
req_to_tokens.stride(0),
|
||||
)
|
||||
return
|
||||
|
||||
_fwd_kernel_stage2[grid](
|
||||
logics,
|
||||
v_buffer,
|
||||
@@ -298,7 +249,6 @@ def _token_softmax_reducev_fwd(
|
||||
num_warps=num_warps,
|
||||
num_stages=3,
|
||||
)
|
||||
cached_kernel_stage2 = wrap_kernel_launcher(_fwd_kernel_stage2)
|
||||
|
||||
|
||||
def token_attention_fwd(
|
||||
|
||||
@@ -51,6 +51,7 @@ from sglang.srt.utils import (
|
||||
allocate_init_ports,
|
||||
assert_pkg_version,
|
||||
enable_show_time_cost,
|
||||
set_ulimit,
|
||||
)
|
||||
from sglang.utils import get_exception_traceback
|
||||
|
||||
@@ -145,30 +146,6 @@ def _set_global_server_args(server_args: ServerArgs):
|
||||
}
|
||||
|
||||
|
||||
def _set_ulimit(target_soft_limit=65535):
|
||||
import resource
|
||||
|
||||
resource_type = resource.RLIMIT_NOFILE
|
||||
current_soft, current_hard = resource.getrlimit(resource_type)
|
||||
|
||||
if current_soft >= target_soft_limit:
|
||||
logger.info(
|
||||
f"Current limits are already sufficient: soft={current_soft}, hard={current_hard}"
|
||||
)
|
||||
else:
|
||||
try:
|
||||
resource.setrlimit(resource_type, (target_soft_limit, current_hard))
|
||||
new_soft, new_hard = resource.getrlimit(resource_type)
|
||||
logger.info(
|
||||
f"Successfully set new limits: soft={new_soft}, hard={new_hard}"
|
||||
)
|
||||
except ValueError as e:
|
||||
logger.warn(f"Failed to set new limits: {e}")
|
||||
logger.info(
|
||||
f"Limits remain unchanged: soft={current_soft}, hard={current_hard}"
|
||||
)
|
||||
|
||||
|
||||
def launch_server(
|
||||
server_args: ServerArgs,
|
||||
model_overide_args: Optional[dict] = None,
|
||||
@@ -186,7 +163,7 @@ def launch_server(
|
||||
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
|
||||
os.environ["NCCL_CUMEM_ENABLE"] = "0"
|
||||
os.environ["NCCL_NVLS_ENABLE"] = "0"
|
||||
_set_ulimit()
|
||||
set_ulimit()
|
||||
if server_args.show_time_cost:
|
||||
enable_show_time_cost()
|
||||
if server_args.disable_disk_cache:
|
||||
|
||||
@@ -5,6 +5,7 @@ import fcntl
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
import resource
|
||||
import socket
|
||||
import struct
|
||||
import time
|
||||
@@ -16,6 +17,7 @@ import numpy as np
|
||||
import psutil
|
||||
import requests
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import triton
|
||||
from fastapi.responses import JSONResponse
|
||||
from packaging import version as pkg_version
|
||||
@@ -184,71 +186,6 @@ def get_int_token_logit_bias(tokenizer, vocab_size):
|
||||
return logit_bias
|
||||
|
||||
|
||||
def wrap_kernel_launcher(kernel):
|
||||
"""A faster launcher for triton kernels."""
|
||||
if int(triton.__version__.split(".")[0]) >= 3:
|
||||
return None
|
||||
|
||||
gpu_id = torch.cuda.current_device()
|
||||
kernels = kernel.cache[gpu_id].values()
|
||||
kernel = next(iter(kernels))
|
||||
|
||||
# Different trition versions use different low-level names
|
||||
if hasattr(kernel, "cu_function"):
|
||||
kfunction = kernel.cu_function
|
||||
else:
|
||||
kfunction = kernel.function
|
||||
|
||||
if hasattr(kernel, "c_wrapper"):
|
||||
run = kernel.c_wrapper
|
||||
else:
|
||||
run = kernel.run
|
||||
|
||||
add_cluster_dim = True
|
||||
|
||||
def ret_func(grid, num_warps, *args):
|
||||
nonlocal add_cluster_dim
|
||||
|
||||
try:
|
||||
if add_cluster_dim:
|
||||
run(
|
||||
grid[0],
|
||||
grid[1],
|
||||
grid[2],
|
||||
num_warps,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
kernel.shared,
|
||||
0,
|
||||
kfunction,
|
||||
None,
|
||||
None,
|
||||
kernel,
|
||||
*args,
|
||||
)
|
||||
else:
|
||||
run(
|
||||
grid[0],
|
||||
grid[1],
|
||||
grid[2],
|
||||
num_warps,
|
||||
kernel.shared,
|
||||
0,
|
||||
kfunction,
|
||||
None,
|
||||
None,
|
||||
kernel,
|
||||
*args,
|
||||
)
|
||||
except TypeError:
|
||||
add_cluster_dim = not add_cluster_dim
|
||||
ret_func(grid, num_warps, *args)
|
||||
|
||||
return ret_func
|
||||
|
||||
|
||||
def is_multimodal_model(model):
|
||||
from sglang.srt.model_config import ModelConfig
|
||||
|
||||
@@ -512,7 +449,6 @@ def get_ip_address(ifname):
|
||||
|
||||
def send_addrs_to_rank_0(model_port_args, server_args):
|
||||
assert server_args.node_rank != 0 and server_args.dp_size == 1
|
||||
import torch.distributed as dist
|
||||
|
||||
ifname = os.environ.get(
|
||||
"SGLANG_SOCKET_IFNAME", os.environ.get("NCCL_SOCKET_IFNAME", "eth0")
|
||||
@@ -544,7 +480,6 @@ def send_addrs_to_rank_0(model_port_args, server_args):
|
||||
|
||||
def receive_addrs(model_port_args, server_args):
|
||||
assert server_args.node_rank == 0 and server_args.dp_size == 1
|
||||
import torch.distributed as dist
|
||||
|
||||
ifname = os.environ.get(
|
||||
"SGLANG_SOCKET_IFNAME", os.environ.get("NCCL_SOCKET_IFNAME", "eth0")
|
||||
@@ -577,3 +512,14 @@ def receive_addrs(model_port_args, server_args):
|
||||
|
||||
dist.barrier()
|
||||
dist.destroy_process_group()
|
||||
|
||||
|
||||
def set_ulimit(target_soft_limit=65535):
|
||||
resource_type = resource.RLIMIT_NOFILE
|
||||
current_soft, current_hard = resource.getrlimit(resource_type)
|
||||
|
||||
if current_soft < target_soft_limit:
|
||||
try:
|
||||
resource.setrlimit(resource_type, (target_soft_limit, current_hard))
|
||||
except ValueError as e:
|
||||
logger.warn(f"Fail to set RLIMIT_NOFILE: {e}")
|
||||
|
||||
Reference in New Issue
Block a user