Remove cached triton launcher (#656)

Author: Lianmin Zheng
Date: 2024-07-18 23:28:40 -07:00
Committed by: GitHub
Parent: 1b7adbb5a0
Commit: e1792cca24
5 changed files with 15 additions and 210 deletions


@@ -5,6 +5,7 @@ import fcntl
 import logging
 import os
 import random
+import resource
 import socket
 import struct
 import time
@@ -16,6 +17,7 @@ import numpy as np
 import psutil
 import requests
 import torch
+import torch.distributed as dist
 import triton
 from fastapi.responses import JSONResponse
 from packaging import version as pkg_version
@@ -184,71 +186,6 @@ def get_int_token_logit_bias(tokenizer, vocab_size):
     return logit_bias
 
 
-def wrap_kernel_launcher(kernel):
-    """A faster launcher for triton kernels."""
-    if int(triton.__version__.split(".")[0]) >= 3:
-        return None
-
-    gpu_id = torch.cuda.current_device()
-    kernels = kernel.cache[gpu_id].values()
-    kernel = next(iter(kernels))
-
-    # Different trition versions use different low-level names
-    if hasattr(kernel, "cu_function"):
-        kfunction = kernel.cu_function
-    else:
-        kfunction = kernel.function
-
-    if hasattr(kernel, "c_wrapper"):
-        run = kernel.c_wrapper
-    else:
-        run = kernel.run
-
-    add_cluster_dim = True
-
-    def ret_func(grid, num_warps, *args):
-        nonlocal add_cluster_dim
-
-        try:
-            if add_cluster_dim:
-                run(
-                    grid[0],
-                    grid[1],
-                    grid[2],
-                    num_warps,
-                    1,
-                    1,
-                    1,
-                    1,
-                    kernel.shared,
-                    0,
-                    kfunction,
-                    None,
-                    None,
-                    kernel,
-                    *args,
-                )
-            else:
-                run(
-                    grid[0],
-                    grid[1],
-                    grid[2],
-                    num_warps,
-                    kernel.shared,
-                    0,
-                    kfunction,
-                    None,
-                    None,
-                    kernel,
-                    *args,
-                )
-        except TypeError:
-            add_cluster_dim = not add_cluster_dim
-            ret_func(grid, num_warps, *args)
-
-    return ret_func
-
-
 def is_multimodal_model(model):
     from sglang.srt.model_config import ModelConfig
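With the cached wrapper removed, call sites presumably fall back to Triton's standard bracket-launch path, which compiles and caches kernels per device on its own. A minimal sketch of that path follows; add_kernel, its arguments, and the block size are hypothetical illustrations, not part of this commit.

import torch
import triton
import triton.language as tl

@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    # Hypothetical elementwise add; stands in for the kernels that
    # previously went through wrap_kernel_launcher.
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x + y, mask=mask)

x = torch.randn(4096, device="cuda")
y = torch.randn(4096, device="cuda")
out = torch.empty_like(x)
grid = (triton.cdiv(x.numel(), 1024),)
# Triton's own launcher replaces the removed wrapper; it caches the
# compiled binary, so repeated launches skip recompilation.
add_kernel[grid](x, y, out, x.numel(), BLOCK_SIZE=1024, num_warps=4)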
@@ -512,7 +449,6 @@ def get_ip_address(ifname):
 def send_addrs_to_rank_0(model_port_args, server_args):
     assert server_args.node_rank != 0 and server_args.dp_size == 1
-    import torch.distributed as dist
 
     ifname = os.environ.get(
         "SGLANG_SOCKET_IFNAME", os.environ.get("NCCL_SOCKET_IFNAME", "eth0")
@@ -544,7 +480,6 @@ def send_addrs_to_rank_0(model_port_args, server_args):
 def receive_addrs(model_port_args, server_args):
     assert server_args.node_rank == 0 and server_args.dp_size == 1
-    import torch.distributed as dist
 
     ifname = os.environ.get(
         "SGLANG_SOCKET_IFNAME", os.environ.get("NCCL_SOCKET_IFNAME", "eth0")
@@ -577,3 +512,14 @@ def receive_addrs(model_port_args, server_args):
 
     dist.barrier()
     dist.destroy_process_group()
+
+
+def set_ulimit(target_soft_limit=65535):
+    resource_type = resource.RLIMIT_NOFILE
+    current_soft, current_hard = resource.getrlimit(resource_type)
+
+    if current_soft < target_soft_limit:
+        try:
+            resource.setrlimit(resource_type, (target_soft_limit, current_hard))
+        except ValueError as e:
+            logger.warn(f"Fail to set RLIMIT_NOFILE: {e}")
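A usage sketch for the new helper, assuming this diff touches sglang/srt/utils.py and that the server entrypoint calls it once before accepting connections (the import path and call site are assumptions, not shown in this diff):

from sglang.srt.utils import set_ulimit

# Raise the soft RLIMIT_NOFILE so many concurrent sockets and file
# handles do not exhaust the default (often 1024); failure only logs
# a warning instead of aborting startup.
set_ulimit()

Note that an unprivileged process can only raise its soft limit up to the existing hard limit, which is why the helper passes current_hard through unchanged.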