Fix for T4 GPUs (#16)
Co-authored-by: Lianmin Zheng <lianminzheng@gmail.com>
This commit is contained in:
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
|
||||
|
||||
[project]
|
||||
name = "sglang"
|
||||
version = "0.1.3"
|
||||
version = "0.1.4"
|
||||
description = "A structured generation language for LLMs."
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.8"
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
__version__ = "0.1.3"
|
||||
__version__ = "0.1.4"
|
||||
|
||||
from sglang.api import *
|
||||
from sglang.global_config import global_config
|
||||
|
||||
@@ -6,6 +6,9 @@ import triton.language as tl
|
||||
from sglang.srt.utils import wrap_kernel_launcher
|
||||
|
||||
|
||||
CUDA_CAPABILITY = torch.cuda.get_device_capability()
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _fwd_kernel(
|
||||
Q,
|
||||
@@ -120,7 +123,11 @@ cached_kernel = None
|
||||
|
||||
|
||||
def context_attention_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len):
|
||||
BLOCK = 128
|
||||
if CUDA_CAPABILITY[0] >= 8:
|
||||
BLOCK = 128
|
||||
else:
|
||||
BLOCK = 64
|
||||
|
||||
Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
|
||||
assert Lq == Lk and Lk == Lv
|
||||
assert Lk in {16, 32, 64, 128}
|
||||
|
||||
@@ -2,6 +2,10 @@ import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
from sglang.srt.layers.context_flashattention_nopad import context_attention_fwd
|
||||
from sglang.srt.utils import wrap_kernel_launcher
|
||||
|
||||
|
||||
CUDA_CAPABILITY = torch.cuda.get_device_capability()
|
||||
|
||||
|
||||
@triton.jit
|
||||
@@ -153,6 +157,9 @@ def _fwd_kernel(
|
||||
tl.store(O_Extend + offs_o, acc / deno[:, None], mask=mask_m[:, None])
|
||||
|
||||
|
||||
cached_kernel = None
|
||||
|
||||
|
||||
def extend_attention_fwd(
|
||||
q_extend,
|
||||
k_extend,
|
||||
@@ -175,7 +182,11 @@ def extend_attention_fwd(
|
||||
|
||||
k_buffer, v_buffer: (prefix + extend) tensors in mem_manager
|
||||
"""
|
||||
BLOCK_M, BLOCK_N = 128, 128
|
||||
if CUDA_CAPABILITY[0] >= 8:
|
||||
BLOCK_M, BLOCK_N = 128, 128
|
||||
else:
|
||||
BLOCK_M, BLOCK_N = 64, 64
|
||||
|
||||
Lq, Lk, Lv, Lo = (
|
||||
q_extend.shape[-1],
|
||||
k_extend.shape[-1],
|
||||
@@ -193,6 +204,40 @@ def extend_attention_fwd(
|
||||
num_warps = 4 if Lk <= 64 else 8
|
||||
num_stages = 1
|
||||
|
||||
global cached_kernel
|
||||
if cached_kernel:
|
||||
cached_kernel(
|
||||
grid,
|
||||
num_warps,
|
||||
q_extend,
|
||||
k_extend,
|
||||
v_extend,
|
||||
o_extend,
|
||||
k_buffer,
|
||||
v_buffer,
|
||||
req_to_tokens,
|
||||
b_req_idx,
|
||||
b_seq_len,
|
||||
b_start_loc_extend,
|
||||
b_seq_len_extend,
|
||||
sm_scale,
|
||||
kv_group_num,
|
||||
q_extend.stride(0),
|
||||
q_extend.stride(1),
|
||||
k_extend.stride(0),
|
||||
k_extend.stride(1),
|
||||
v_extend.stride(0),
|
||||
v_extend.stride(1),
|
||||
o_extend.stride(0),
|
||||
o_extend.stride(1),
|
||||
k_buffer.stride(0),
|
||||
k_buffer.stride(1),
|
||||
v_buffer.stride(0),
|
||||
v_buffer.stride(1),
|
||||
req_to_tokens.stride(0),
|
||||
)
|
||||
return
|
||||
|
||||
_fwd_kernel[grid](
|
||||
q_extend,
|
||||
k_extend,
|
||||
@@ -226,6 +271,7 @@ def extend_attention_fwd(
|
||||
num_warps=num_warps,
|
||||
num_stages=num_stages,
|
||||
)
|
||||
cached_kernel = wrap_kernel_launcher(_fwd_kernel)
|
||||
|
||||
|
||||
def redundant_attention(
|
||||
|
||||
@@ -5,6 +5,7 @@ import time
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from enum import Enum, auto
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
import warnings
|
||||
|
||||
import numpy as np
|
||||
import rpyc
|
||||
@@ -164,7 +165,7 @@ class ModelRpcServer(rpyc.Service):
|
||||
+ self.tree_cache.evictable_size()
|
||||
)
|
||||
if available_size != self.max_total_num_token:
|
||||
logger.warning(
|
||||
warnings.warn(
|
||||
"Warning: "
|
||||
f"available_size={available_size}, max_total_num_token={self.max_total_num_token}\n"
|
||||
"KV cache pool leak detected!"
|
||||
|
||||
Reference in New Issue
Block a user