Fix for T4 GPUs (#16)

Co-authored-by: Lianmin Zheng <lianminzheng@gmail.com>
Author: Ying Sheng
Date: 2024-01-16 15:49:03 -08:00
Committed by: GitHub
Parent: 5b27a1dce4
Commit: ffe4aaee1d
6 changed files with 68 additions and 6 deletions

python/pyproject.toml

@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project]
name = "sglang"
version = "0.1.3"
version = "0.1.4"
description = "A structured generation language for LLMs."
readme = "README.md"
requires-python = ">=3.8"

python/sglang/__init__.py

@@ -1,4 +1,4 @@
__version__ = "0.1.3"
__version__ = "0.1.4"
from sglang.api import *
from sglang.global_config import global_config

python/sglang/srt/layers/context_flashattention_nopad.py

@@ -6,6 +6,9 @@ import triton.language as tl
from sglang.srt.utils import wrap_kernel_launcher
+CUDA_CAPABILITY = torch.cuda.get_device_capability()
@triton.jit
def _fwd_kernel(
    Q,
@@ -120,7 +123,11 @@ cached_kernel = None
def context_attention_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len):
-    BLOCK = 128
+    if CUDA_CAPABILITY[0] >= 8:
+        BLOCK = 128
+    else:
+        BLOCK = 64
    Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
    assert Lq == Lk and Lk == Lv
    assert Lk in {16, 32, 64, 128}
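The gate above is the core of the T4 fix: `torch.cuda.get_device_capability()` returns a `(major, minor)` tuple, and pre-Ampere parts such as the T4 (sm_75) get a smaller Triton tile, presumably to stay within their tighter per-SM shared-memory and register budgets. A minimal standalone sketch of the same check (the helper name is hypothetical):

```python
import torch

# (major, minor) of the current CUDA device, e.g. (7, 5) on a T4,
# (8, 0) on an A100. Requires a visible CUDA device.
CUDA_CAPABILITY = torch.cuda.get_device_capability()

def pick_block_size() -> int:
    # Ampere (sm_80) and newer can afford a 128-wide tile; older
    # GPUs like the T4 fall back to 64 so the kernel still launches.
    return 128 if CUDA_CAPABILITY[0] >= 8 else 64
```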

python/sglang/srt/layers/extend_attention.py

@@ -2,6 +2,10 @@ import torch
import triton
import triton.language as tl
from sglang.srt.layers.context_flashattention_nopad import context_attention_fwd
+from sglang.srt.utils import wrap_kernel_launcher
+CUDA_CAPABILITY = torch.cuda.get_device_capability()
@triton.jit
@@ -153,6 +157,9 @@ def _fwd_kernel(
    tl.store(O_Extend + offs_o, acc / deno[:, None], mask=mask_m[:, None])
+cached_kernel = None
def extend_attention_fwd(
    q_extend,
    k_extend,
@@ -175,7 +182,11 @@ def extend_attention_fwd(
    k_buffer, v_buffer: (prefix + extend) tensors in mem_manager
    """
-    BLOCK_M, BLOCK_N = 128, 128
+    if CUDA_CAPABILITY[0] >= 8:
+        BLOCK_M, BLOCK_N = 128, 128
+    else:
+        BLOCK_M, BLOCK_N = 64, 64
    Lq, Lk, Lv, Lo = (
        q_extend.shape[-1],
        k_extend.shape[-1],
@@ -193,6 +204,40 @@ def extend_attention_fwd(
    num_warps = 4 if Lk <= 64 else 8
    num_stages = 1
+    global cached_kernel
+    if cached_kernel:
+        cached_kernel(
+            grid,
+            num_warps,
+            q_extend,
+            k_extend,
+            v_extend,
+            o_extend,
+            k_buffer,
+            v_buffer,
+            req_to_tokens,
+            b_req_idx,
+            b_seq_len,
+            b_start_loc_extend,
+            b_seq_len_extend,
+            sm_scale,
+            kv_group_num,
+            q_extend.stride(0),
+            q_extend.stride(1),
+            k_extend.stride(0),
+            k_extend.stride(1),
+            v_extend.stride(0),
+            v_extend.stride(1),
+            o_extend.stride(0),
+            o_extend.stride(1),
+            k_buffer.stride(0),
+            k_buffer.stride(1),
+            v_buffer.stride(0),
+            v_buffer.stride(1),
+            req_to_tokens.stride(0),
+        )
+        return
    _fwd_kernel[grid](
        q_extend,
        k_extend,
@@ -226,6 +271,7 @@ def extend_attention_fwd(
        num_warps=num_warps,
        num_stages=num_stages,
    )
+    cached_kernel = wrap_kernel_launcher(_fwd_kernel)
def redundant_attention(
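The big hunk above layers a launch-time optimization on the T4 fix: the first call goes through Triton's normal `_fwd_kernel[grid](...)` dispatch, then `wrap_kernel_launcher` (imported from `sglang.srt.utils`; its internals are not shown in this diff) captures the compiled kernel so later calls replay it with positional arguments only, skipping Triton's per-call Python dispatch. A simplified sketch of the pattern with a toy kernel, assuming `wrap_kernel_launcher` returns a callable taking `(grid, num_warps, *non_constexpr_args)` as the diff's call site suggests:

```python
import torch
import triton
import triton.language as tl

from sglang.srt.utils import wrap_kernel_launcher  # as imported in the diff

@triton.jit
def _add_kernel(X, Y, Out, n_elements, BLOCK: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n_elements
    x = tl.load(X + offs, mask=mask)
    y = tl.load(Y + offs, mask=mask)
    tl.store(Out + offs, x + y, mask=mask)

cached_kernel = None  # one cached specialization per process

def add(x, y):
    global cached_kernel
    out = torch.empty_like(x)
    n = x.numel()
    grid = (triton.cdiv(n, 128),)
    if cached_kernel:
        # Fast path: replay the compiled kernel; constexpr args (BLOCK)
        # are already baked in, so only runtime args are passed.
        cached_kernel(grid, 4, x, y, out, n)
        return out
    # Slow path (first call): regular Triton launch, which JIT-compiles
    # and specializes the kernel for these argument types.
    _add_kernel[grid](x, y, out, n, BLOCK=128, num_warps=4)
    cached_kernel = wrap_kernel_launcher(_add_kernel)
    return out
```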

python/sglang/srt/managers/router/model_rpc.py

@@ -5,6 +5,7 @@ import time
from concurrent.futures import ThreadPoolExecutor
from enum import Enum, auto
from typing import Dict, List, Optional, Tuple, Union
+import warnings
import numpy as np
import rpyc
@@ -164,7 +165,7 @@ class ModelRpcServer(rpyc.Service):
            + self.tree_cache.evictable_size()
        )
        if available_size != self.max_total_num_token:
-            logger.warning(
+            warnings.warn(
                "Warning: "
                f"available_size={available_size}, max_total_num_token={self.max_total_num_token}\n"
                "KV cache pool leak detected!"