Fix for T4 GPUs (#16)

Co-authored-by: Lianmin Zheng <lianminzheng@gmail.com>

diff --git a/README.md b/README.md
@@ -32,6 +32,10 @@ pip install --upgrade pip
 pip install -e "python[all]"
 ```
 
+### Notes
+- If you are using older GPUs (NVIDIA T4, V100), please use `pip install "triton>=2.2.0"` to avoid some bugs in the triton compiler
+- If you only need to use the OpenAI backend, you can avoid installing other dependencies by using `pip install sglang[openai]`
+
 ## Quick Start
 
 The example below shows how to use sglang to answer a multi-turn question.
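For context, the multi-turn example that this README section refers to looks roughly like the sketch below, based on the sglang 0.1.x frontend API; the backend URL, questions, and function name are illustrative assumptions:

```python
import sglang as sgl

# Assumes an SRT server is already running on port 30000 (see "Usage" below).
sgl.set_default_backend(sgl.RuntimeEndpoint("http://localhost:30000"))

@sgl.function
def multi_turn_question(s, question_1, question_2):
    # Each += appends a segment to the running prompt state `s`.
    s += sgl.user(question_1)
    s += sgl.assistant(sgl.gen("answer_1", max_tokens=256))
    s += sgl.user(question_2)
    s += sgl.assistant(sgl.gen("answer_2", max_tokens=256))

state = multi_turn_question.run(
    question_1="What is the capital of France?",
    question_2="List two landmarks there.",
)
print(state["answer_1"])
print(state["answer_2"])
```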
@@ -197,7 +201,7 @@ for out in state.text_iter():
 ## Backend: SGLang Runtime (SRT)
 The SGLang Runtime (SRT) is designed to work best with the SGLang frontend.
 However, it can also be used as a standalone API server.
-In this case, the [RadixAttention](https://arxiv.org/abs/2312.07104) can still greatly accelerate many use cases.
+In this case, the [RadixAttention](https://arxiv.org/abs/2312.07104) can still greatly accelerate many use cases with automatic KV cache reuse.
 
 ### Usage
 Launch a server
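To illustrate the standalone-server use mentioned in this hunk, a minimal Python client might look like the following; the `/generate` endpoint and payload shape are assumptions about SRT of this era, not part of this diff:

```python
import requests

# Assumes a server started with:
#   python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000
response = requests.post(
    "http://localhost:30000/generate",
    json={
        "text": "The capital of France is",
        "sampling_params": {"max_new_tokens": 16, "temperature": 0},
    },
)
print(response.json())
```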
@@ -221,6 +225,10 @@ curl http://localhost:30000/v1/completions \
 ```
 python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000 --tp 2
 ```
+- If you see out-of-memory errors during serving, please try to reduce the memory usage of the KV cache pool by setting a smaller value of `--mem-fraction-static`. The default value is `0.9`
+```
+python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000 --mem-fraction-static 0.7
+```
 
 ### Supported Models
 - Llama
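To build intuition for `--mem-fraction-static`, here is a back-of-the-envelope sketch; the exact accounting SRT performs is not shown in this diff, so treat the split below as an assumption:

```python
# Hypothetical memory budget for Llama-2-7b (fp16, ~2 bytes/param) on a 40 GB GPU.
gpu_mem_gb = 40.0
weights_gb = 13.5            # ~7B params * 2 bytes, rough estimate
mem_fraction_static = 0.7    # value passed via --mem-fraction-static

# Assumed split: the static fraction covers weights plus the KV cache pool;
# the remainder is left as headroom for activations and temporary buffers.
kv_pool_gb = gpu_mem_gb * mem_fraction_static - weights_gb
print(f"KV cache pool ~= {kv_pool_gb:.1f} GB")  # ~= 14.5 GB
```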
diff --git a/python/pyproject.toml b/python/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
|
|||||||
|
|
||||||
[project]
|
[project]
|
||||||
name = "sglang"
|
name = "sglang"
|
||||||
version = "0.1.3"
|
version = "0.1.4"
|
||||||
description = "A structured generation langauge for LLMs."
|
description = "A structured generation langauge for LLMs."
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
requires-python = ">=3.8"
|
requires-python = ">=3.8"
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
-__version__ = "0.1.3"
+__version__ = "0.1.4"
 
 from sglang.api import *
 from sglang.global_config import global_config

diff --git a/python/sglang/srt/layers/context_flashattention_nopad.py b/python/sglang/srt/layers/context_flashattention_nopad.py
@@ -6,6 +6,9 @@ import triton.language as tl
 from sglang.srt.utils import wrap_kernel_launcher
 
 
+CUDA_CAPABILITY = torch.cuda.get_device_capability()
+
+
 @triton.jit
 def _fwd_kernel(
     Q,
@@ -120,7 +123,11 @@ cached_kernel = None
 
 
 def context_attention_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len):
-    BLOCK = 128
+    if CUDA_CAPABILITY[0] >= 8:
+        BLOCK = 128
+    else:
+        BLOCK = 64
+
     Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
     assert Lq == Lk and Lk == Lv
     assert Lk in {16, 32, 64, 128}
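The core of the T4 fix is this capability-based tile size. Below is a minimal runnable sketch of the dispatch; the shared-memory rationale in the comments is my reading of the change, not stated in the commit:

```python
import torch

# (major, minor) compute capability: (7, 5) on T4, (7, 0) on V100,
# (8, 0) on A100. Ampere and newer report major >= 8.
CUDA_CAPABILITY = torch.cuda.get_device_capability()

def pick_block() -> int:
    # Pre-Ampere GPUs have less shared memory per SM, so a 128-wide
    # Triton tile can fail to launch there; fall back to 64 (presumed
    # motivation for the branch added above).
    return 128 if CUDA_CAPABILITY[0] >= 8 else 64
```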
diff --git a/python/sglang/srt/layers/extend_attention.py b/python/sglang/srt/layers/extend_attention.py
@@ -2,6 +2,10 @@ import torch
 import triton
 import triton.language as tl
 from sglang.srt.layers.context_flashattention_nopad import context_attention_fwd
+from sglang.srt.utils import wrap_kernel_launcher
 
 
+CUDA_CAPABILITY = torch.cuda.get_device_capability()
+
+
 @triton.jit
@@ -153,6 +157,9 @@ def _fwd_kernel(
     tl.store(O_Extend + offs_o, acc / deno[:, None], mask=mask_m[:, None])
 
 
+cached_kernel = None
+
+
 def extend_attention_fwd(
     q_extend,
     k_extend,
@@ -175,7 +182,11 @@ def extend_attention_fwd(
 
     k_buffer, v_buffer: (prefix + extend) tensors in mem_manager
     """
-    BLOCK_M, BLOCK_N = 128, 128
+    if CUDA_CAPABILITY[0] >= 8:
+        BLOCK_M, BLOCK_N = 128, 128
+    else:
+        BLOCK_M, BLOCK_N = 64, 64
+
     Lq, Lk, Lv, Lo = (
         q_extend.shape[-1],
         k_extend.shape[-1],
@@ -193,6 +204,40 @@ def extend_attention_fwd(
     num_warps = 4 if Lk <= 64 else 8
     num_stages = 1
 
+    global cached_kernel
+    if cached_kernel:
+        cached_kernel(
+            grid,
+            num_warps,
+            q_extend,
+            k_extend,
+            v_extend,
+            o_extend,
+            k_buffer,
+            v_buffer,
+            req_to_tokens,
+            b_req_idx,
+            b_seq_len,
+            b_start_loc_extend,
+            b_seq_len_extend,
+            sm_scale,
+            kv_group_num,
+            q_extend.stride(0),
+            q_extend.stride(1),
+            k_extend.stride(0),
+            k_extend.stride(1),
+            v_extend.stride(0),
+            v_extend.stride(1),
+            o_extend.stride(0),
+            o_extend.stride(1),
+            k_buffer.stride(0),
+            k_buffer.stride(1),
+            v_buffer.stride(0),
+            v_buffer.stride(1),
+            req_to_tokens.stride(0),
+        )
+        return
+
     _fwd_kernel[grid](
         q_extend,
         k_extend,
@@ -226,6 +271,7 @@ def extend_attention_fwd(
         num_warps=num_warps,
         num_stages=num_stages,
     )
+    cached_kernel = wrap_kernel_launcher(_fwd_kernel)
 
 
 def redundant_attention(
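The `cached_kernel` changes above implement a compile-once, launch-many pattern: the first call goes through `_fwd_kernel[grid](...)` (which triggers Triton's JIT), after which `wrap_kernel_launcher` captures a pre-bound launcher that later calls reuse. A self-contained Python analogue of the pattern, with a stand-in for the JIT step:

```python
from typing import Callable, Optional

_cached_launcher: Optional[Callable[..., None]] = None

def expensive_compile() -> Callable[..., None]:
    # Stands in for Triton JIT compilation on the first kernel launch.
    print("compiling...")
    return lambda *args: print("launch", args)

def launch(*args) -> None:
    global _cached_launcher
    if _cached_launcher is not None:
        _cached_launcher(*args)     # fast path: skip dispatch/compile checks
        return
    launcher = expensive_compile()  # first call pays the compile cost
    launcher(*args)
    _cached_launcher = launcher     # cache for subsequent calls

launch(1, 2)  # prints "compiling..." then "launch (1, 2)"
launch(3, 4)  # prints only "launch (3, 4)"
```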
diff --git a/python/sglang/srt/managers/router/model_rpc.py b/python/sglang/srt/managers/router/model_rpc.py
@@ -5,6 +5,7 @@ import time
 from concurrent.futures import ThreadPoolExecutor
 from enum import Enum, auto
 from typing import Dict, List, Optional, Tuple, Union
+import warnings
 
 import numpy as np
 import rpyc
@@ -164,7 +165,7 @@ class ModelRpcServer(rpyc.Service):
             + self.tree_cache.evictable_size()
         )
         if available_size != self.max_total_num_token:
-            logger.warning(
+            warnings.warn(
                 "Warning: "
                 f"available_size={available_size}, max_total_num_token={self.max_total_num_token}\n"
                 "KV cache pool leak detected!"
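One plausible reason for swapping `logger.warning` for `warnings.warn` (the commit itself gives none) is the warnings module's default deduplication: a warning raised from the same source location prints only once per run, so a leak check that fires on every scheduling step does not flood the console. A runnable illustration:

```python
import warnings

def check(available_size: int, max_total_num_token: int) -> None:
    if available_size != max_total_num_token:
        warnings.warn(
            f"available_size={available_size}, "
            f"max_total_num_token={max_total_num_token}\n"
            "KV cache pool leak detected!"
        )

for _ in range(3):
    check(10, 12)  # the UserWarning prints only once under default filters
```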