Fix for T4 GPUs (#16)

Co-authored-by: Lianmin Zheng <lianminzheng@gmail.com>
This commit is contained in:
Ying Sheng
2024-01-16 15:49:03 -08:00
committed by GitHub
parent 5b27a1dce4
commit ffe4aaee1d
6 changed files with 68 additions and 6 deletions

View File

@@ -32,6 +32,10 @@ pip install --upgrade pip
pip install -e "python[all]" pip install -e "python[all]"
``` ```
### Notes
- If you are using older GPUs (NVIDIA T4, V100), please use `pip install "triton>=2.2.0"` to avoid some bugs in the triton compiler
- If you only need to use the OpenAI backend, you can avoid installing other dependencies by using `pip install sglang[openai]`
## Quick Start ## Quick Start
The example below shows how to use sglang to answer a multi-turn question. The example below shows how to use sglang to answer a multi-turn question.
@@ -197,7 +201,7 @@ for out in state.text_iter():
## Backend: SGLang Runtime (SRT) ## Backend: SGLang Runtime (SRT)
The SGLang Runtime (SRT) is designed to work best with the SGLang frontend. The SGLang Runtime (SRT) is designed to work best with the SGLang frontend.
However, it can also be used as a standalone API server. However, it can also be used as a standalone API server.
In this case, the [RadixAttention](https://arxiv.org/abs/2312.07104) can still greatly accelerate many use cases. In this case, the [RadixAttention](https://arxiv.org/abs/2312.07104) can still greatly accelerate many use cases with automatic KV cache reuse.
### Usage ### Usage
Launch a server Launch a server
@@ -221,6 +225,10 @@ curl http://localhost:30000/v1/completions \
``` ```
python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000 --tp 2 python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000 --tp 2
``` ```
- If you see out-of-memory errors during serving, please try to reduce the memory usage of the KV cache pool by setting a smaller value of `--mem-fraction-static`. The default value is `0.9`
```
python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000 --mem-fraction-static 0.7
```
### Supported Models ### Supported Models
- Llama - Llama

View File

@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project] [project]
name = "sglang" name = "sglang"
version = "0.1.3" version = "0.1.4"
description = "A structured generation language for LLMs." description = "A structured generation language for LLMs."
readme = "README.md" readme = "README.md"
requires-python = ">=3.8" requires-python = ">=3.8"

View File

@@ -1,4 +1,4 @@
__version__ = "0.1.3" __version__ = "0.1.4"
from sglang.api import * from sglang.api import *
from sglang.global_config import global_config from sglang.global_config import global_config

View File

@@ -6,6 +6,9 @@ import triton.language as tl
from sglang.srt.utils import wrap_kernel_launcher from sglang.srt.utils import wrap_kernel_launcher
CUDA_CAPABILITY = torch.cuda.get_device_capability()
@triton.jit @triton.jit
def _fwd_kernel( def _fwd_kernel(
Q, Q,
@@ -120,7 +123,11 @@ cached_kernel = None
def context_attention_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len): def context_attention_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len):
if CUDA_CAPABILITY[0] >= 8:
BLOCK = 128 BLOCK = 128
else:
BLOCK = 64
Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
assert Lq == Lk and Lk == Lv assert Lq == Lk and Lk == Lv
assert Lk in {16, 32, 64, 128} assert Lk in {16, 32, 64, 128}

View File

@@ -2,6 +2,10 @@ import torch
import triton import triton
import triton.language as tl import triton.language as tl
from sglang.srt.layers.context_flashattention_nopad import context_attention_fwd from sglang.srt.layers.context_flashattention_nopad import context_attention_fwd
from sglang.srt.utils import wrap_kernel_launcher
CUDA_CAPABILITY = torch.cuda.get_device_capability()
@triton.jit @triton.jit
@@ -153,6 +157,9 @@ def _fwd_kernel(
tl.store(O_Extend + offs_o, acc / deno[:, None], mask=mask_m[:, None]) tl.store(O_Extend + offs_o, acc / deno[:, None], mask=mask_m[:, None])
cached_kernel = None
def extend_attention_fwd( def extend_attention_fwd(
q_extend, q_extend,
k_extend, k_extend,
@@ -175,7 +182,11 @@ def extend_attention_fwd(
k_buffer, v_buffer: (prefix + extend) tensors in mem_manager k_buffer, v_buffer: (prefix + extend) tensors in mem_manager
""" """
if CUDA_CAPABILITY[0] >= 8:
BLOCK_M, BLOCK_N = 128, 128 BLOCK_M, BLOCK_N = 128, 128
else:
BLOCK_M, BLOCK_N = 64, 64
Lq, Lk, Lv, Lo = ( Lq, Lk, Lv, Lo = (
q_extend.shape[-1], q_extend.shape[-1],
k_extend.shape[-1], k_extend.shape[-1],
@@ -193,6 +204,40 @@ def extend_attention_fwd(
num_warps = 4 if Lk <= 64 else 8 num_warps = 4 if Lk <= 64 else 8
num_stages = 1 num_stages = 1
global cached_kernel
if cached_kernel:
cached_kernel(
grid,
num_warps,
q_extend,
k_extend,
v_extend,
o_extend,
k_buffer,
v_buffer,
req_to_tokens,
b_req_idx,
b_seq_len,
b_start_loc_extend,
b_seq_len_extend,
sm_scale,
kv_group_num,
q_extend.stride(0),
q_extend.stride(1),
k_extend.stride(0),
k_extend.stride(1),
v_extend.stride(0),
v_extend.stride(1),
o_extend.stride(0),
o_extend.stride(1),
k_buffer.stride(0),
k_buffer.stride(1),
v_buffer.stride(0),
v_buffer.stride(1),
req_to_tokens.stride(0),
)
return
_fwd_kernel[grid]( _fwd_kernel[grid](
q_extend, q_extend,
k_extend, k_extend,
@@ -226,6 +271,7 @@ def extend_attention_fwd(
num_warps=num_warps, num_warps=num_warps,
num_stages=num_stages, num_stages=num_stages,
) )
cached_kernel = wrap_kernel_launcher(_fwd_kernel)
def redundant_attention( def redundant_attention(

View File

@@ -5,6 +5,7 @@ import time
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from enum import Enum, auto from enum import Enum, auto
from typing import Dict, List, Optional, Tuple, Union from typing import Dict, List, Optional, Tuple, Union
import warnings
import numpy as np import numpy as np
import rpyc import rpyc
@@ -164,7 +165,7 @@ class ModelRpcServer(rpyc.Service):
+ self.tree_cache.evictable_size() + self.tree_cache.evictable_size()
) )
if available_size != self.max_total_num_token: if available_size != self.max_total_num_token:
logger.warning( warnings.warn(
"Warning: " "Warning: "
f"available_size={available_size}, max_total_num_token={self.max_total_num_token}\n" f"available_size={available_size}, max_total_num_token={self.max_total_num_token}\n"
"KV cache pool leak detected!" "KV cache pool leak detected!"