Turn on flashinfer by default (#578)
This commit is contained in:
@@ -26,7 +26,7 @@ class RadixAttention(nn.Module):
|
||||
|
||||
from sglang.srt.managers.controller.model_runner import global_server_args_dict
|
||||
|
||||
if global_server_args_dict.get("enable_flashinfer", False):
|
||||
if not global_server_args_dict.get("disable_flashinfer", False):
|
||||
self.prefill_forward = self.prefill_forward_flashinfer
|
||||
self.extend_forward = self.prefill_forward_flashinfer
|
||||
self.decode_forward = self.decode_forward_flashinfer
|
||||
|
||||
@@ -201,7 +201,7 @@ class InputMetadata:
|
||||
if forward_mode == ForwardMode.EXTEND:
|
||||
ret.init_extend_args()
|
||||
|
||||
if global_server_args_dict.get("enable_flashinfer", False):
|
||||
if not global_server_args_dict.get("disable_flashinfer", False):
|
||||
ret.init_flashinfer_args(
|
||||
model_runner.model_config.num_attention_heads // tp_size,
|
||||
model_runner.model_config.get_num_kv_heads(tp_size),
|
||||
@@ -263,7 +263,7 @@ class ModelRunner:
|
||||
# Set some global args
|
||||
global global_server_args_dict
|
||||
global_server_args_dict = {
|
||||
"enable_flashinfer": server_args.enable_flashinfer,
|
||||
"disable_flashinfer": server_args.disable_flashinfer,
|
||||
"attention_reduce_in_fp32": server_args.attention_reduce_in_fp32,
|
||||
}
|
||||
|
||||
@@ -359,7 +359,7 @@ class ModelRunner:
|
||||
return c
|
||||
|
||||
def init_flash_infer(self):
|
||||
if global_server_args_dict.get("enable_flashinfer", False):
|
||||
if not global_server_args_dict.get("disable_flashinfer", False):
|
||||
from flashinfer import (
|
||||
BatchPrefillWithPagedKVCacheWrapper,
|
||||
BatchDecodeWithPagedKVCacheWrapper,
|
||||
|
||||
@@ -50,7 +50,7 @@ class ServerArgs:
|
||||
load_balance_method: str = "round_robin"
|
||||
|
||||
# Optimization/debug options
|
||||
enable_flashinfer: bool = False
|
||||
disable_flashinfer: bool = True
|
||||
attention_reduce_in_fp32: bool = False
|
||||
disable_radix_cache: bool = False
|
||||
disable_regex_jump_forward: bool = False
|
||||
@@ -287,9 +287,9 @@ class ServerArgs:
|
||||
|
||||
# Optimization/debug options
|
||||
parser.add_argument(
|
||||
"--enable-flashinfer",
|
||||
"--disable-flashinfer",
|
||||
action="store_true",
|
||||
help="Enable flashinfer inference kernels",
|
||||
help="Disable flashinfer inference kernels",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--attention-reduce-in-fp32",
|
||||
@@ -322,7 +322,7 @@ class ServerArgs:
|
||||
|
||||
def print_mode_args(self):
|
||||
return (
|
||||
f"enable_flashinfer={self.enable_flashinfer}, "
|
||||
f"disable_flashinfer={self.disable_flashinfer}, "
|
||||
f"attention_reduce_in_fp32={self.attention_reduce_in_fp32}, "
|
||||
f"disable_radix_cache={self.disable_radix_cache}, "
|
||||
f"disable_regex_jump_forward={self.disable_regex_jump_forward}, "
|
||||
|
||||
Reference in New Issue
Block a user