Remove cached triton launcher (#656)
This commit is contained in:
@@ -4,8 +4,6 @@ import torch
|
|||||||
import triton
|
import triton
|
||||||
import triton.language as tl
|
import triton.language as tl
|
||||||
|
|
||||||
from sglang.srt.utils import wrap_kernel_launcher
|
|
||||||
|
|
||||||
CUDA_CAPABILITY = torch.cuda.get_device_capability()
|
CUDA_CAPABILITY = torch.cuda.get_device_capability()
|
||||||
|
|
||||||
|
|
||||||
@@ -119,9 +117,6 @@ def _fwd_kernel(
|
|||||||
tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len)
|
tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len)
|
||||||
|
|
||||||
|
|
||||||
cached_kernel = None
|
|
||||||
|
|
||||||
|
|
||||||
def context_attention_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len):
|
def context_attention_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len):
|
||||||
if CUDA_CAPABILITY[0] >= 8:
|
if CUDA_CAPABILITY[0] >= 8:
|
||||||
BLOCK = 128
|
BLOCK = 128
|
||||||
@@ -139,29 +134,6 @@ def context_attention_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len):
|
|||||||
grid = (batch, head, triton.cdiv(max_input_len, BLOCK))
|
grid = (batch, head, triton.cdiv(max_input_len, BLOCK))
|
||||||
num_warps = 4 if Lk <= 64 else 8
|
num_warps = 4 if Lk <= 64 else 8
|
||||||
|
|
||||||
global cached_kernel
|
|
||||||
if cached_kernel:
|
|
||||||
cached_kernel(
|
|
||||||
grid,
|
|
||||||
num_warps,
|
|
||||||
q,
|
|
||||||
k,
|
|
||||||
v,
|
|
||||||
sm_scale,
|
|
||||||
b_start_loc,
|
|
||||||
b_seq_len,
|
|
||||||
o,
|
|
||||||
q.stride(0),
|
|
||||||
q.stride(1),
|
|
||||||
k.stride(0),
|
|
||||||
k.stride(1),
|
|
||||||
v.stride(0),
|
|
||||||
v.stride(1),
|
|
||||||
o.stride(0),
|
|
||||||
o.stride(1),
|
|
||||||
)
|
|
||||||
return
|
|
||||||
|
|
||||||
_fwd_kernel[grid](
|
_fwd_kernel[grid](
|
||||||
q,
|
q,
|
||||||
k,
|
k,
|
||||||
@@ -185,4 +157,3 @@ def context_attention_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len):
|
|||||||
num_warps=num_warps,
|
num_warps=num_warps,
|
||||||
num_stages=1,
|
num_stages=1,
|
||||||
)
|
)
|
||||||
cached_kernel = wrap_kernel_launcher(_fwd_kernel)
|
|
||||||
|
|||||||
@@ -3,7 +3,6 @@ import triton
|
|||||||
import triton.language as tl
|
import triton.language as tl
|
||||||
|
|
||||||
from sglang.srt.layers.context_flashattention_nopad import context_attention_fwd
|
from sglang.srt.layers.context_flashattention_nopad import context_attention_fwd
|
||||||
from sglang.srt.utils import wrap_kernel_launcher
|
|
||||||
|
|
||||||
CUDA_CAPABILITY = torch.cuda.get_device_capability()
|
CUDA_CAPABILITY = torch.cuda.get_device_capability()
|
||||||
|
|
||||||
@@ -172,9 +171,6 @@ def _fwd_kernel(
|
|||||||
tl.store(O_Extend + offs_o, acc / deno[:, None], mask=mask_m[:, None])
|
tl.store(O_Extend + offs_o, acc / deno[:, None], mask=mask_m[:, None])
|
||||||
|
|
||||||
|
|
||||||
cached_kernel = None
|
|
||||||
|
|
||||||
|
|
||||||
def extend_attention_fwd(
|
def extend_attention_fwd(
|
||||||
q_extend,
|
q_extend,
|
||||||
k_extend,
|
k_extend,
|
||||||
@@ -222,40 +218,6 @@ def extend_attention_fwd(
|
|||||||
num_warps = 4 if Lk <= 64 else 8
|
num_warps = 4 if Lk <= 64 else 8
|
||||||
num_stages = 1
|
num_stages = 1
|
||||||
|
|
||||||
global cached_kernel
|
|
||||||
if cached_kernel:
|
|
||||||
cached_kernel(
|
|
||||||
grid,
|
|
||||||
num_warps,
|
|
||||||
q_extend,
|
|
||||||
k_extend,
|
|
||||||
v_extend,
|
|
||||||
o_extend,
|
|
||||||
k_buffer,
|
|
||||||
v_buffer,
|
|
||||||
req_to_tokens,
|
|
||||||
b_req_idx,
|
|
||||||
b_seq_len,
|
|
||||||
b_start_loc_extend,
|
|
||||||
b_seq_len_extend,
|
|
||||||
sm_scale,
|
|
||||||
kv_group_num,
|
|
||||||
q_extend.stride(0),
|
|
||||||
q_extend.stride(1),
|
|
||||||
k_extend.stride(0),
|
|
||||||
k_extend.stride(1),
|
|
||||||
v_extend.stride(0),
|
|
||||||
v_extend.stride(1),
|
|
||||||
o_extend.stride(0),
|
|
||||||
o_extend.stride(1),
|
|
||||||
k_buffer.stride(0),
|
|
||||||
k_buffer.stride(1),
|
|
||||||
v_buffer.stride(0),
|
|
||||||
v_buffer.stride(1),
|
|
||||||
req_to_tokens.stride(0),
|
|
||||||
)
|
|
||||||
return
|
|
||||||
|
|
||||||
_fwd_kernel[grid](
|
_fwd_kernel[grid](
|
||||||
q_extend,
|
q_extend,
|
||||||
k_extend,
|
k_extend,
|
||||||
@@ -290,7 +252,6 @@ def extend_attention_fwd(
|
|||||||
num_stages=num_stages,
|
num_stages=num_stages,
|
||||||
logit_cap=logit_cap,
|
logit_cap=logit_cap,
|
||||||
)
|
)
|
||||||
cached_kernel = wrap_kernel_launcher(_fwd_kernel)
|
|
||||||
|
|
||||||
|
|
||||||
def redundant_attention(
|
def redundant_attention(
|
||||||
|
|||||||
@@ -6,7 +6,6 @@ import triton
|
|||||||
import triton.language as tl
|
import triton.language as tl
|
||||||
|
|
||||||
from sglang.srt.server import global_server_args_dict
|
from sglang.srt.server import global_server_args_dict
|
||||||
from sglang.srt.utils import wrap_kernel_launcher
|
|
||||||
|
|
||||||
if global_server_args_dict.get("attention_reduce_in_fp32", False):
|
if global_server_args_dict.get("attention_reduce_in_fp32", False):
|
||||||
REDUCE_TRITON_TYPE = tl.float32
|
REDUCE_TRITON_TYPE = tl.float32
|
||||||
@@ -162,10 +161,6 @@ def _fwd_kernel_stage2(
|
|||||||
tl.store(out_ptrs, acc)
|
tl.store(out_ptrs, acc)
|
||||||
|
|
||||||
|
|
||||||
cached_kernel_stage1 = None
|
|
||||||
cached_kernel_stage2 = None
|
|
||||||
|
|
||||||
|
|
||||||
def _token_att_m_fwd(
|
def _token_att_m_fwd(
|
||||||
q,
|
q,
|
||||||
k_buffer,
|
k_buffer,
|
||||||
@@ -194,28 +189,6 @@ def _token_att_m_fwd(
|
|||||||
else:
|
else:
|
||||||
num_warps = 2
|
num_warps = 2
|
||||||
|
|
||||||
global cached_kernel_stage1
|
|
||||||
if cached_kernel_stage1:
|
|
||||||
cached_kernel_stage1(
|
|
||||||
grid,
|
|
||||||
num_warps,
|
|
||||||
q,
|
|
||||||
k_buffer,
|
|
||||||
sm_scale,
|
|
||||||
Req_to_tokens,
|
|
||||||
B_req_idx,
|
|
||||||
B_Start_Loc,
|
|
||||||
B_Seqlen,
|
|
||||||
att_out,
|
|
||||||
Req_to_tokens.stride(0),
|
|
||||||
q.stride(0),
|
|
||||||
q.stride(1),
|
|
||||||
k_buffer.stride(0),
|
|
||||||
k_buffer.stride(1),
|
|
||||||
att_out.stride(0),
|
|
||||||
)
|
|
||||||
return
|
|
||||||
|
|
||||||
_fwd_kernel_stage1[grid](
|
_fwd_kernel_stage1[grid](
|
||||||
q,
|
q,
|
||||||
k_buffer,
|
k_buffer,
|
||||||
@@ -238,7 +211,6 @@ def _token_att_m_fwd(
|
|||||||
num_warps=num_warps,
|
num_warps=num_warps,
|
||||||
num_stages=1,
|
num_stages=1,
|
||||||
)
|
)
|
||||||
cached_kernel_stage1 = wrap_kernel_launcher(_fwd_kernel_stage1)
|
|
||||||
|
|
||||||
|
|
||||||
def _token_softmax_reducev_fwd(
|
def _token_softmax_reducev_fwd(
|
||||||
@@ -257,27 +229,6 @@ def _token_softmax_reducev_fwd(
|
|||||||
|
|
||||||
num_warps = 1
|
num_warps = 1
|
||||||
|
|
||||||
global cached_kernel_stage2
|
|
||||||
if cached_kernel_stage2:
|
|
||||||
cached_kernel_stage2(
|
|
||||||
grid,
|
|
||||||
num_warps,
|
|
||||||
logics,
|
|
||||||
v_buffer,
|
|
||||||
o,
|
|
||||||
req_to_tokens,
|
|
||||||
b_req_idx,
|
|
||||||
b_start_loc,
|
|
||||||
b_seq_len,
|
|
||||||
logics.stride(0),
|
|
||||||
v_buffer.stride(0),
|
|
||||||
v_buffer.stride(1),
|
|
||||||
o.stride(0),
|
|
||||||
o.stride(1),
|
|
||||||
req_to_tokens.stride(0),
|
|
||||||
)
|
|
||||||
return
|
|
||||||
|
|
||||||
_fwd_kernel_stage2[grid](
|
_fwd_kernel_stage2[grid](
|
||||||
logics,
|
logics,
|
||||||
v_buffer,
|
v_buffer,
|
||||||
@@ -298,7 +249,6 @@ def _token_softmax_reducev_fwd(
|
|||||||
num_warps=num_warps,
|
num_warps=num_warps,
|
||||||
num_stages=3,
|
num_stages=3,
|
||||||
)
|
)
|
||||||
cached_kernel_stage2 = wrap_kernel_launcher(_fwd_kernel_stage2)
|
|
||||||
|
|
||||||
|
|
||||||
def token_attention_fwd(
|
def token_attention_fwd(
|
||||||
|
|||||||
@@ -51,6 +51,7 @@ from sglang.srt.utils import (
|
|||||||
allocate_init_ports,
|
allocate_init_ports,
|
||||||
assert_pkg_version,
|
assert_pkg_version,
|
||||||
enable_show_time_cost,
|
enable_show_time_cost,
|
||||||
|
set_ulimit,
|
||||||
)
|
)
|
||||||
from sglang.utils import get_exception_traceback
|
from sglang.utils import get_exception_traceback
|
||||||
|
|
||||||
@@ -145,30 +146,6 @@ def _set_global_server_args(server_args: ServerArgs):
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def _set_ulimit(target_soft_limit=65535):
|
|
||||||
import resource
|
|
||||||
|
|
||||||
resource_type = resource.RLIMIT_NOFILE
|
|
||||||
current_soft, current_hard = resource.getrlimit(resource_type)
|
|
||||||
|
|
||||||
if current_soft >= target_soft_limit:
|
|
||||||
logger.info(
|
|
||||||
f"Current limits are already sufficient: soft={current_soft}, hard={current_hard}"
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
try:
|
|
||||||
resource.setrlimit(resource_type, (target_soft_limit, current_hard))
|
|
||||||
new_soft, new_hard = resource.getrlimit(resource_type)
|
|
||||||
logger.info(
|
|
||||||
f"Successfully set new limits: soft={new_soft}, hard={new_hard}"
|
|
||||||
)
|
|
||||||
except ValueError as e:
|
|
||||||
logger.warn(f"Failed to set new limits: {e}")
|
|
||||||
logger.info(
|
|
||||||
f"Limits remain unchanged: soft={current_soft}, hard={current_hard}"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def launch_server(
|
def launch_server(
|
||||||
server_args: ServerArgs,
|
server_args: ServerArgs,
|
||||||
model_overide_args: Optional[dict] = None,
|
model_overide_args: Optional[dict] = None,
|
||||||
@@ -186,7 +163,7 @@ def launch_server(
|
|||||||
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
|
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
|
||||||
os.environ["NCCL_CUMEM_ENABLE"] = "0"
|
os.environ["NCCL_CUMEM_ENABLE"] = "0"
|
||||||
os.environ["NCCL_NVLS_ENABLE"] = "0"
|
os.environ["NCCL_NVLS_ENABLE"] = "0"
|
||||||
_set_ulimit()
|
set_ulimit()
|
||||||
if server_args.show_time_cost:
|
if server_args.show_time_cost:
|
||||||
enable_show_time_cost()
|
enable_show_time_cost()
|
||||||
if server_args.disable_disk_cache:
|
if server_args.disable_disk_cache:
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import fcntl
|
|||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import random
|
import random
|
||||||
|
import resource
|
||||||
import socket
|
import socket
|
||||||
import struct
|
import struct
|
||||||
import time
|
import time
|
||||||
@@ -16,6 +17,7 @@ import numpy as np
|
|||||||
import psutil
|
import psutil
|
||||||
import requests
|
import requests
|
||||||
import torch
|
import torch
|
||||||
|
import torch.distributed as dist
|
||||||
import triton
|
import triton
|
||||||
from fastapi.responses import JSONResponse
|
from fastapi.responses import JSONResponse
|
||||||
from packaging import version as pkg_version
|
from packaging import version as pkg_version
|
||||||
@@ -184,71 +186,6 @@ def get_int_token_logit_bias(tokenizer, vocab_size):
|
|||||||
return logit_bias
|
return logit_bias
|
||||||
|
|
||||||
|
|
||||||
def wrap_kernel_launcher(kernel):
|
|
||||||
"""A faster launcher for triton kernels."""
|
|
||||||
if int(triton.__version__.split(".")[0]) >= 3:
|
|
||||||
return None
|
|
||||||
|
|
||||||
gpu_id = torch.cuda.current_device()
|
|
||||||
kernels = kernel.cache[gpu_id].values()
|
|
||||||
kernel = next(iter(kernels))
|
|
||||||
|
|
||||||
# Different trition versions use different low-level names
|
|
||||||
if hasattr(kernel, "cu_function"):
|
|
||||||
kfunction = kernel.cu_function
|
|
||||||
else:
|
|
||||||
kfunction = kernel.function
|
|
||||||
|
|
||||||
if hasattr(kernel, "c_wrapper"):
|
|
||||||
run = kernel.c_wrapper
|
|
||||||
else:
|
|
||||||
run = kernel.run
|
|
||||||
|
|
||||||
add_cluster_dim = True
|
|
||||||
|
|
||||||
def ret_func(grid, num_warps, *args):
|
|
||||||
nonlocal add_cluster_dim
|
|
||||||
|
|
||||||
try:
|
|
||||||
if add_cluster_dim:
|
|
||||||
run(
|
|
||||||
grid[0],
|
|
||||||
grid[1],
|
|
||||||
grid[2],
|
|
||||||
num_warps,
|
|
||||||
1,
|
|
||||||
1,
|
|
||||||
1,
|
|
||||||
1,
|
|
||||||
kernel.shared,
|
|
||||||
0,
|
|
||||||
kfunction,
|
|
||||||
None,
|
|
||||||
None,
|
|
||||||
kernel,
|
|
||||||
*args,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
run(
|
|
||||||
grid[0],
|
|
||||||
grid[1],
|
|
||||||
grid[2],
|
|
||||||
num_warps,
|
|
||||||
kernel.shared,
|
|
||||||
0,
|
|
||||||
kfunction,
|
|
||||||
None,
|
|
||||||
None,
|
|
||||||
kernel,
|
|
||||||
*args,
|
|
||||||
)
|
|
||||||
except TypeError:
|
|
||||||
add_cluster_dim = not add_cluster_dim
|
|
||||||
ret_func(grid, num_warps, *args)
|
|
||||||
|
|
||||||
return ret_func
|
|
||||||
|
|
||||||
|
|
||||||
def is_multimodal_model(model):
|
def is_multimodal_model(model):
|
||||||
from sglang.srt.model_config import ModelConfig
|
from sglang.srt.model_config import ModelConfig
|
||||||
|
|
||||||
@@ -512,7 +449,6 @@ def get_ip_address(ifname):
|
|||||||
|
|
||||||
def send_addrs_to_rank_0(model_port_args, server_args):
|
def send_addrs_to_rank_0(model_port_args, server_args):
|
||||||
assert server_args.node_rank != 0 and server_args.dp_size == 1
|
assert server_args.node_rank != 0 and server_args.dp_size == 1
|
||||||
import torch.distributed as dist
|
|
||||||
|
|
||||||
ifname = os.environ.get(
|
ifname = os.environ.get(
|
||||||
"SGLANG_SOCKET_IFNAME", os.environ.get("NCCL_SOCKET_IFNAME", "eth0")
|
"SGLANG_SOCKET_IFNAME", os.environ.get("NCCL_SOCKET_IFNAME", "eth0")
|
||||||
@@ -544,7 +480,6 @@ def send_addrs_to_rank_0(model_port_args, server_args):
|
|||||||
|
|
||||||
def receive_addrs(model_port_args, server_args):
|
def receive_addrs(model_port_args, server_args):
|
||||||
assert server_args.node_rank == 0 and server_args.dp_size == 1
|
assert server_args.node_rank == 0 and server_args.dp_size == 1
|
||||||
import torch.distributed as dist
|
|
||||||
|
|
||||||
ifname = os.environ.get(
|
ifname = os.environ.get(
|
||||||
"SGLANG_SOCKET_IFNAME", os.environ.get("NCCL_SOCKET_IFNAME", "eth0")
|
"SGLANG_SOCKET_IFNAME", os.environ.get("NCCL_SOCKET_IFNAME", "eth0")
|
||||||
@@ -577,3 +512,14 @@ def receive_addrs(model_port_args, server_args):
|
|||||||
|
|
||||||
dist.barrier()
|
dist.barrier()
|
||||||
dist.destroy_process_group()
|
dist.destroy_process_group()
|
||||||
|
|
||||||
|
|
||||||
|
def set_ulimit(target_soft_limit=65535):
|
||||||
|
resource_type = resource.RLIMIT_NOFILE
|
||||||
|
current_soft, current_hard = resource.getrlimit(resource_type)
|
||||||
|
|
||||||
|
if current_soft < target_soft_limit:
|
||||||
|
try:
|
||||||
|
resource.setrlimit(resource_type, (target_soft_limit, current_hard))
|
||||||
|
except ValueError as e:
|
||||||
|
logger.warn(f"Fail to set RLIMIT_NOFILE: {e}")
|
||||||
|
|||||||
Reference in New Issue
Block a user