Turn on flashinfer by default (#578)
This commit is contained in:
@@ -34,6 +34,8 @@ The core features include:
|
|||||||
pip install "sglang[all]"
|
pip install "sglang[all]"
|
||||||
```
|
```
|
||||||
|
|
||||||
|
Next, [install FlashInfer](https://docs.flashinfer.ai/installation.html) for attention CUDA kernels.
|
||||||
|
|
||||||
### Method 2: From source
|
### Method 2: From source
|
||||||
```
|
```
|
||||||
git clone https://github.com/sgl-project/sglang.git
|
git clone https://github.com/sgl-project/sglang.git
|
||||||
@@ -43,7 +45,11 @@ pip install --upgrade pip
|
|||||||
pip install -e "python[all]"
|
pip install -e "python[all]"
|
||||||
```
|
```
|
||||||
|
|
||||||
|
Next, [install FlashInfer](https://docs.flashinfer.ai/installation.html) for attention CUDA kernels.
|
||||||
|
|
||||||
### Notes
|
### Notes
|
||||||
|
- If you see triton errors, please install the [Triton Nightly](https://triton-lang.org/main/getting-started/installation.html).
|
||||||
|
- If you cannot install FlashInfer, you can use the slower triton kernels by adding `--disable-flashinfer` when launching the server.
|
||||||
- If you only need to use the OpenAI backend, you can avoid installing other dependencies by using `pip install "sglang[openai]"`
|
- If you only need to use the OpenAI backend, you can avoid installing other dependencies by using `pip install "sglang[openai]"`
|
||||||
|
|
||||||
## Quick Start
|
## Quick Start
|
||||||
@@ -363,7 +369,6 @@ python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port
|
|||||||
```
|
```
|
||||||
python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000 --mem-fraction-static 0.7
|
python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000 --mem-fraction-static 0.7
|
||||||
```
|
```
|
||||||
- See [flashinfer.md](docs/flashinfer.md) on accelerating inference using highly optimized CUDA kernels.
|
|
||||||
- See [hyperparameter_tuning.md](docs/hyperparameter_tuning.md) on tuning hyperparameters for better performance.
|
- See [hyperparameter_tuning.md](docs/hyperparameter_tuning.md) on tuning hyperparameters for better performance.
|
||||||
|
|
||||||
### Supported Models
|
### Supported Models
|
||||||
|
|||||||
@@ -1,18 +0,0 @@
|
|||||||
## Flashinfer Mode
|
|
||||||
|
|
||||||
[flashinfer](https://github.com/flashinfer-ai/flashinfer) is a kernel library for LLM serving.
|
|
||||||
It can be used in SGLang runtime to accelerate attention computation.
|
|
||||||
|
|
||||||
### Install flashinfer
|
|
||||||
|
|
||||||
See https://docs.flashinfer.ai/installation.html.
|
|
||||||
|
|
||||||
### Run a Server With Flashinfer Mode
|
|
||||||
|
|
||||||
Add `--enable-flashinfer` argument to enable flashinfer when launching a server.
|
|
||||||
|
|
||||||
Example:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000 --enable-flashinfer
|
|
||||||
```
|
|
||||||
@@ -26,7 +26,7 @@ class RadixAttention(nn.Module):
|
|||||||
|
|
||||||
from sglang.srt.managers.controller.model_runner import global_server_args_dict
|
from sglang.srt.managers.controller.model_runner import global_server_args_dict
|
||||||
|
|
||||||
if global_server_args_dict.get("enable_flashinfer", False):
|
if not global_server_args_dict.get("disable_flashinfer", False):
|
||||||
self.prefill_forward = self.prefill_forward_flashinfer
|
self.prefill_forward = self.prefill_forward_flashinfer
|
||||||
self.extend_forward = self.prefill_forward_flashinfer
|
self.extend_forward = self.prefill_forward_flashinfer
|
||||||
self.decode_forward = self.decode_forward_flashinfer
|
self.decode_forward = self.decode_forward_flashinfer
|
||||||
|
|||||||
@@ -201,7 +201,7 @@ class InputMetadata:
|
|||||||
if forward_mode == ForwardMode.EXTEND:
|
if forward_mode == ForwardMode.EXTEND:
|
||||||
ret.init_extend_args()
|
ret.init_extend_args()
|
||||||
|
|
||||||
if global_server_args_dict.get("enable_flashinfer", False):
|
if not global_server_args_dict.get("disable_flashinfer", False):
|
||||||
ret.init_flashinfer_args(
|
ret.init_flashinfer_args(
|
||||||
model_runner.model_config.num_attention_heads // tp_size,
|
model_runner.model_config.num_attention_heads // tp_size,
|
||||||
model_runner.model_config.get_num_kv_heads(tp_size),
|
model_runner.model_config.get_num_kv_heads(tp_size),
|
||||||
@@ -263,7 +263,7 @@ class ModelRunner:
|
|||||||
# Set some global args
|
# Set some global args
|
||||||
global global_server_args_dict
|
global global_server_args_dict
|
||||||
global_server_args_dict = {
|
global_server_args_dict = {
|
||||||
"enable_flashinfer": server_args.enable_flashinfer,
|
"disable_flashinfer": server_args.disable_flashinfer,
|
||||||
"attention_reduce_in_fp32": server_args.attention_reduce_in_fp32,
|
"attention_reduce_in_fp32": server_args.attention_reduce_in_fp32,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -359,7 +359,7 @@ class ModelRunner:
|
|||||||
return c
|
return c
|
||||||
|
|
||||||
def init_flash_infer(self):
|
def init_flash_infer(self):
|
||||||
if global_server_args_dict.get("enable_flashinfer", False):
|
if not global_server_args_dict.get("disable_flashinfer", False):
|
||||||
from flashinfer import (
|
from flashinfer import (
|
||||||
BatchPrefillWithPagedKVCacheWrapper,
|
BatchPrefillWithPagedKVCacheWrapper,
|
||||||
BatchDecodeWithPagedKVCacheWrapper,
|
BatchDecodeWithPagedKVCacheWrapper,
|
||||||
|
|||||||
@@ -50,7 +50,7 @@ class ServerArgs:
|
|||||||
load_balance_method: str = "round_robin"
|
load_balance_method: str = "round_robin"
|
||||||
|
|
||||||
# Optimization/debug options
|
# Optimization/debug options
|
||||||
enable_flashinfer: bool = False
|
disable_flashinfer: bool = True
|
||||||
attention_reduce_in_fp32: bool = False
|
attention_reduce_in_fp32: bool = False
|
||||||
disable_radix_cache: bool = False
|
disable_radix_cache: bool = False
|
||||||
disable_regex_jump_forward: bool = False
|
disable_regex_jump_forward: bool = False
|
||||||
@@ -287,9 +287,9 @@ class ServerArgs:
|
|||||||
|
|
||||||
# Optimization/debug options
|
# Optimization/debug options
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--enable-flashinfer",
|
"--disable-flashinfer",
|
||||||
action="store_true",
|
action="store_true",
|
||||||
help="Enable flashinfer inference kernels",
|
help="Disable flashinfer inference kernels",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--attention-reduce-in-fp32",
|
"--attention-reduce-in-fp32",
|
||||||
@@ -322,7 +322,7 @@ class ServerArgs:
|
|||||||
|
|
||||||
def print_mode_args(self):
|
def print_mode_args(self):
|
||||||
return (
|
return (
|
||||||
f"enable_flashinfer={self.enable_flashinfer}, "
|
f"disable_flashinfer={self.disable_flashinfer}, "
|
||||||
f"attention_reduce_in_fp32={self.attention_reduce_in_fp32}, "
|
f"attention_reduce_in_fp32={self.attention_reduce_in_fp32}, "
|
||||||
f"disable_radix_cache={self.disable_radix_cache}, "
|
f"disable_radix_cache={self.disable_radix_cache}, "
|
||||||
f"disable_regex_jump_forward={self.disable_regex_jump_forward}, "
|
f"disable_regex_jump_forward={self.disable_regex_jump_forward}, "
|
||||||
|
|||||||
Reference in New Issue
Block a user