Remove cached triton launcher (#656)
@@ -5,6 +5,7 @@ import fcntl
import logging
import os
import random
import resource
import socket
import struct
import time
@@ -16,6 +17,7 @@ import numpy as np
import psutil
import requests
import torch
import torch.distributed as dist
import triton
from fastapi.responses import JSONResponse
from packaging import version as pkg_version
@@ -184,71 +186,6 @@ def get_int_token_logit_bias(tokenizer, vocab_size):
    return logit_bias


def wrap_kernel_launcher(kernel):
    """A faster launcher for triton kernels."""
    if int(triton.__version__.split(".")[0]) >= 3:
        return None

    gpu_id = torch.cuda.current_device()
    kernels = kernel.cache[gpu_id].values()
    kernel = next(iter(kernels))

    # Different triton versions use different low-level names
    if hasattr(kernel, "cu_function"):
        kfunction = kernel.cu_function
    else:
        kfunction = kernel.function

    if hasattr(kernel, "c_wrapper"):
        run = kernel.c_wrapper
    else:
        run = kernel.run

    add_cluster_dim = True

    def ret_func(grid, num_warps, *args):
        nonlocal add_cluster_dim

        try:
            if add_cluster_dim:
                run(
                    grid[0],
                    grid[1],
                    grid[2],
                    num_warps,
                    1,
                    1,
                    1,
                    1,
                    kernel.shared,
                    0,
                    kfunction,
                    None,
                    None,
                    kernel,
                    *args,
                )
            else:
                run(
                    grid[0],
                    grid[1],
                    grid[2],
                    num_warps,
                    kernel.shared,
                    0,
                    kfunction,
                    None,
                    None,
                    kernel,
                    *args,
                )
        except TypeError:
            add_cluster_dim = not add_cluster_dim
            ret_func(grid, num_warps, *args)

    return ret_func


def is_multimodal_model(model):
    from sglang.srt.model_config import ModelConfig

@@ -512,7 +449,6 @@ def get_ip_address(ifname):

def send_addrs_to_rank_0(model_port_args, server_args):
    assert server_args.node_rank != 0 and server_args.dp_size == 1
    import torch.distributed as dist

    ifname = os.environ.get(
        "SGLANG_SOCKET_IFNAME", os.environ.get("NCCL_SOCKET_IFNAME", "eth0")
@@ -544,7 +480,6 @@ def send_addrs_to_rank_0(model_port_args, server_args):

def receive_addrs(model_port_args, server_args):
    assert server_args.node_rank == 0 and server_args.dp_size == 1
    import torch.distributed as dist

    ifname = os.environ.get(
        "SGLANG_SOCKET_IFNAME", os.environ.get("NCCL_SOCKET_IFNAME", "eth0")
@@ -577,3 +512,14 @@ def receive_addrs(model_port_args, server_args):

    dist.barrier()
    dist.destroy_process_group()


def set_ulimit(target_soft_limit=65535):
    resource_type = resource.RLIMIT_NOFILE
    current_soft, current_hard = resource.getrlimit(resource_type)

    if current_soft < target_soft_limit:
        try:
            resource.setrlimit(resource_type, (target_soft_limit, current_hard))
        except ValueError as e:
            logger.warn(f"Fail to set RLIMIT_NOFILE: {e}")
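
For context, the removed wrap_kernel_launcher cached a compiled kernel's low-level entry point to skip Triton's Python dispatch. Without it, call sites presumably fall back to Triton's standard kernel[grid](...) launch. Below is a minimal sketch of that standard path with a toy @triton.jit kernel; the kernel, tensor names, and launch parameters are illustrative and not taken from this commit.

import torch
import triton
import triton.language as tl


@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    # Each program instance handles one BLOCK_SIZE-wide slice of the inputs.
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x + y, mask=mask)


x = torch.randn(4096, device="cuda")
y = torch.randn(4096, device="cuda")
out = torch.empty_like(x)
n = x.numel()
# Standard launch: Triton compiles and caches the kernel on first call and
# dispatches through its own launcher rather than a hand-rolled wrapper.
grid = (triton.cdiv(n, 1024),)
add_kernel[grid](x, y, out, n, BLOCK_SIZE=1024, num_warps=4)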
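
The added set_ulimit helper only raises the soft RLIMIT_NOFILE ceiling, so it is presumably meant to run once at process startup, before the server opens its listening sockets. A hedged usage sketch follows; the import path is an assumption based on the file contents visible in this diff.

import resource

# Assumed module path for the helper added in this commit.
from sglang.srt.utils import set_ulimit

set_ulimit()  # raise the soft file-descriptor limit to 65535 if it is lower

soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
print(f"RLIMIT_NOFILE: soft={soft}, hard={hard}")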