Deprecate --disable-flashinfer and introduce --attention-backend (#1380)
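In short, the boolean --disable-flashinfer flag is deprecated in favour of an explicit --attention-backend choice; the values that appear in this diff are "flashinfer" and "triton", with a matching sampling_backend option alongside. A minimal sketch of selecting a backend through ServerArgs; the constructor arguments and defaults are assumptions, only the attention_backend / sampling_backend field names come from the diff below:

    from sglang.srt.server_args import ServerArgs  # same import used by ModelRunner below

    # Equivalent CLI spelling (assumed): --attention-backend triton --sampling-backend pytorch
    args = ServerArgs(
        model_path="meta-llama/Meta-Llama-3-8B-Instruct",  # placeholder model
        attention_backend="triton",   # replaces the old --disable-flashinfer
        sampling_backend="pytorch",   # assumed value; replaces --disable-flashinfer-sampling
    )
    assert args.attention_backend in ("flashinfer", "triton")
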
@@ -53,7 +53,7 @@ from sglang.srt.mem_cache.memory_pool import (
     MLATokenToKVPool,
     ReqToTokenPool,
 )
-from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
+from sglang.srt.model_executor.forward_batch_info import InputMetadata
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
     get_available_gpu_memory,
@@ -92,8 +92,8 @@ class ModelRunner:
         )
         global_server_args_dict.update(
             {
-                "disable_flashinfer": server_args.disable_flashinfer,
-                "disable_flashinfer_sampling": server_args.disable_flashinfer_sampling,
+                "attention_backend": server_args.attention_backend,
+                "sampling_backend": server_args.sampling_backend,
                 "triton_attention_reduce_in_fp32": server_args.triton_attention_reduce_in_fp32,
                 "enable_mla": server_args.enable_mla,
                 "torchao_config": server_args.torchao_config,
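Because global_server_args_dict now stores backend names rather than booleans, code that previously branched on disable_flashinfer presumably compares against the backend string instead. A rough sketch of that migration pattern; the helper function and the import path for global_server_args_dict are assumptions, not shown in this hunk:

    # Import path assumed; only the dict keys come from the hunk above.
    from sglang.srt.managers.schedule_batch import global_server_args_dict

    def use_flashinfer_kernels() -> bool:
        # Old check (pre-#1380): not global_server_args_dict["disable_flashinfer"]
        return global_server_args_dict["attention_backend"] == "flashinfer"
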
@@ -111,7 +111,7 @@ class ModelRunner:
         self.load_model()
         self.init_memory_pool(
             min_per_gpu_memory,
-            server_args.max_num_reqs,
+            server_args.max_running_requests,
             server_args.max_total_tokens,
         )
         self.init_cublas()
@@ -344,8 +344,8 @@ class ModelRunner:
     def init_memory_pool(
         self,
         total_gpu_memory: int,
-        max_num_reqs: int = None,
-        max_total_tokens: int = None,
+        max_num_reqs: Optional[int] = None,
+        max_total_tokens: Optional[int] = None,
     ):
         if self.server_args.kv_cache_dtype == "auto":
             self.kv_cache_dtype = self.dtype
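The annotation change above is the standard typing fix: a parameter whose default is None should be typed Optional[...] (PEP 484), which also means the module needs from typing import Optional. A self-contained illustration of the before/after signature; the function here is a stand-in, not the real method:

    from typing import Optional

    # Before: "max_num_reqs: int = None" is flagged by type checkers, since None is not an int.
    # After: the None default is spelled out in the type.
    def init_memory_pool_stub(
        total_gpu_memory: int,
        max_num_reqs: Optional[int] = None,
        max_total_tokens: Optional[int] = None,
    ) -> None:
        pass
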
@@ -379,7 +379,7 @@ class ModelRunner:
                     ),
                     2048,
                 ),
-                5120,
+                4096,
             )

         self.req_to_token_pool = ReqToTokenPool(
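Only the tail of the expression is visible in this hunk, but the 2048 and 5120→4096 literals read like a floor/ceiling clamp on a derived request count; the change simply lowers the ceiling to 4096. A standalone sketch of that clamp pattern, with the derived estimate treated as an opaque input (the real surrounding expression is not shown here):

    def clamp_max_num_reqs(estimated: int, floor: int = 2048, ceiling: int = 4096) -> int:
        # min(max(x, floor), ceiling): never fewer than `floor` request slots,
        # and after this commit never more than 4096 (previously 5120).
        return min(max(estimated, floor), ceiling)

    assert clamp_max_num_reqs(1_000) == 2048    # small estimate -> floor
    assert clamp_max_num_reqs(3_000) == 3000    # in range -> unchanged
    assert clamp_max_num_reqs(10_000) == 4096   # large estimate -> new ceiling
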
@@ -399,7 +399,7 @@ class ModelRunner:
             )
             logger.info("using MLA Triton implementaion, flashinfer is disabled")
             # FIXME: temporarily only Triton MLA is supported
-            self.server_args.disable_flashinfer = True
+            self.server_args.attention_backend = "triton"
         else:
             self.token_to_kv_pool = MHATokenToKVPool(
                 self.max_total_num_tokens,
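For MLA models the behaviour is unchanged in spirit: the Triton path is still forced, but it is now expressed by overwriting attention_backend instead of flipping disable_flashinfer. A small standalone sketch of that forced-override pattern (the function name is illustrative; only the attribute and the "triton" value come from the hunk):

    import logging

    logger = logging.getLogger(__name__)

    def force_triton_mla_backend(server_args) -> None:
        # Mirrors the FIXME above: while only the Triton MLA kernels exist,
        # whatever backend the user asked for is overridden to "triton".
        if server_args.attention_backend != "triton":
            logger.info("MLA is enabled; forcing attention_backend to 'triton'")
            server_args.attention_backend = "triton"
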
@@ -424,7 +424,7 @@ class ModelRunner:

     def init_flashinfer(self):
         """Init flashinfer attention kernel wrappers."""
-        if self.server_args.disable_flashinfer:
+        if self.server_args.attention_backend != "flashinfer":
             assert (
                 self.sliding_window_size is None
             ), "turn on flashinfer to support window attention"
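The guard flips from "flashinfer is disabled" to "the selected backend is not flashinfer", and the assert still ties sliding-window attention to the FlashInfer path. The same invariant could be phrased as an early argument check; a sketch under that assumption (function name and error text are illustrative):

    from typing import Optional

    def check_window_attention(attention_backend: str, sliding_window_size: Optional[int]) -> None:
        # Same condition as the assert above, raised before model setup instead.
        if attention_backend != "flashinfer" and sliding_window_size is not None:
            raise ValueError("sliding-window attention requires --attention-backend flashinfer")
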
@@ -491,7 +491,10 @@ class ModelRunner:

         from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner

-        if self.server_args.disable_cuda_graph or self.server_args.disable_flashinfer:
+        if (
+            self.server_args.disable_cuda_graph
+            or self.server_args.attention_backend != "flashinfer"
+        ):
             self.cuda_graph_runner = None
             return

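Finally, CUDA graph capture is now skipped whenever the backend is not FlashInfer (or graphs are explicitly disabled), rather than keying off the old boolean. A compact sketch of that decision as a helper; the CudaGraphRunner constructor arguments are assumed, only the condition itself comes from the hunk:

    def maybe_build_cuda_graph_runner(model_runner):
        server_args = model_runner.server_args
        # Graphs are only captured on the FlashInfer path; --disable-cuda-graph still wins.
        if server_args.disable_cuda_graph or server_args.attention_backend != "flashinfer":
            return None
        from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
        return CudaGraphRunner(model_runner)  # constructor signature assumed
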