Deprecate --disable-flashinfer and introduce --attention-backend (#1380)

Author: Lianmin Zheng (committed by GitHub)
Date: 2024-09-10 17:11:16 -07:00
parent 3a6e8b6d78
commit 46094e0c1b
13 changed files with 99 additions and 61 deletions

python/sglang/srt/model_executor/model_runner.py

@@ -53,7 +53,7 @@ from sglang.srt.mem_cache.memory_pool import (
     MLATokenToKVPool,
     ReqToTokenPool,
 )
-from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
+from sglang.srt.model_executor.forward_batch_info import InputMetadata
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
     get_available_gpu_memory,
@@ -92,8 +92,8 @@ class ModelRunner:
         )
         global_server_args_dict.update(
             {
-                "disable_flashinfer": server_args.disable_flashinfer,
-                "disable_flashinfer_sampling": server_args.disable_flashinfer_sampling,
+                "attention_backend": server_args.attention_backend,
+                "sampling_backend": server_args.sampling_backend,
                 "triton_attention_reduce_in_fp32": server_args.triton_attention_reduce_in_fp32,
                 "enable_mla": server_args.enable_mla,
                 "torchao_config": server_args.torchao_config,
@@ -111,7 +111,7 @@ class ModelRunner:
         self.load_model()
         self.init_memory_pool(
             min_per_gpu_memory,
-            server_args.max_num_reqs,
+            server_args.max_running_requests,
             server_args.max_total_tokens,
         )
         self.init_cublas()
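
The two hunks above do related cleanup: the negated booleans disable_flashinfer / disable_flashinfer_sampling become the string-valued choices attention_backend / sampling_backend published through global_server_args_dict, and the init_memory_pool call site now passes server_args.max_running_requests, the name that matches what the limit actually controls. The deprecation shim for the old flags is not shown in these hunks; a minimal sketch of the mapping it implies, assuming the deprecated fields are still parsed for one release (the helper name and the "pytorch" sampling fallback are assumptions, not taken from this diff):

    def resolve_backends(disable_flashinfer: bool, disable_flashinfer_sampling: bool):
        # Hypothetical shim: translate the deprecated booleans into the new
        # string-valued choices. "triton" is the non-flashinfer attention
        # fallback used elsewhere in this diff; "pytorch" is an assumption.
        attention_backend = "triton" if disable_flashinfer else "flashinfer"
        sampling_backend = "pytorch" if disable_flashinfer_sampling else "flashinfer"
        return attention_backend, sampling_backend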
@@ -344,8 +344,8 @@ class ModelRunner:
     def init_memory_pool(
         self,
         total_gpu_memory: int,
-        max_num_reqs: int = None,
-        max_total_tokens: int = None,
+        max_num_reqs: Optional[int] = None,
+        max_total_tokens: Optional[int] = None,
     ):
         if self.server_args.kv_cache_dtype == "auto":
             self.kv_cache_dtype = self.dtype
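
The signature fix above matters beyond style: `max_num_reqs: int = None` relies on the implicit-Optional behavior that PEP 484 deprecates and strict type checkers reject. Spelling it out also requires the typing import; a standalone sketch of the corrected signature (the body is elided, only the annotations are the point):

    from typing import Optional

    def init_memory_pool(
        total_gpu_memory: int,
        max_num_reqs: Optional[int] = None,     # was `int = None` (implicit Optional)
        max_total_tokens: Optional[int] = None,
    ) -> None:
        ...  # body elided; only the annotation change is shown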
@@ -379,7 +379,7 @@ class ModelRunner:
                 ),
                 2048,
             ),
-            5120,
+            4096,
         )
         self.req_to_token_pool = ReqToTokenPool(
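
This hunk lowers the upper clamp on the derived default for max_num_reqs from 5120 to 4096. Assuming the estimate feeding the clamp is int(max_total_num_tokens / context_len * 512), reconstructed from the surrounding file rather than shown in the hunk, a worked example makes the new bound concrete:

    # Worked example of the default max_num_reqs clamp after this change.
    max_total_num_tokens = 500_000   # illustrative pool size, not from the diff
    context_len = 4_096

    estimate = int(max_total_num_tokens / context_len * 512)   # 62_500
    max_num_reqs = min(max(estimate, 2048), 4096)
    print(max_num_reqs)   # 4096: clamped by the lowered cap (previously 5120)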
@@ -399,7 +399,7 @@ class ModelRunner:
             )
             logger.info("using MLA Triton implementaion, flashinfer is disabled")
             # FIXME: temporarily only Triton MLA is supported
-            self.server_args.disable_flashinfer = True
+            self.server_args.attention_backend = "triton"
         else:
             self.token_to_kv_pool = MHATokenToKVPool(
                 self.max_total_num_tokens,
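
Because only the Triton MLA kernels exist at this point (hence the FIXME), enabling MLA now overwrites the user's backend choice rather than flipping a boolean. A sketch of the resulting precedence, with a hypothetical helper name:

    def effective_attention_backend(requested: str, enable_mla: bool) -> str:
        # Hypothetical helper: an MLA run is Triton-only at this commit, so
        # it wins over whatever backend the user requested.
        return "triton" if enable_mla else requested

    assert effective_attention_backend("flashinfer", enable_mla=True) == "triton"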
@@ -424,7 +424,7 @@ class ModelRunner:
     def init_flashinfer(self):
         """Init flashinfer attention kernel wrappers."""
-        if self.server_args.disable_flashinfer:
+        if self.server_args.attention_backend != "flashinfer":
             assert (
                 self.sliding_window_size is None
             ), "turn on flashinfer to support window attention"
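
init_flashinfer now keys off the backend string, and the assertion states the constraint directly: sliding-window attention is only wired up through FlashInfer. A usage sketch of the new option (constructing ServerArgs this way is an assumption about its dataclass interface; the model path is a placeholder):

    from sglang.srt.server_args import ServerArgs

    # Hypothetical usage: select the backend by name instead of passing the
    # deprecated --disable-flashinfer switch.
    args = ServerArgs(
        model_path="placeholder/model",   # placeholder, not a real checkpoint
        attention_backend="triton",       # choice introduced by this commit
    )
    # Models with a finite sliding_window_size must keep
    # attention_backend="flashinfer", per the assertion above.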
@@ -491,7 +491,10 @@ class ModelRunner:
         from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner

-        if self.server_args.disable_cuda_graph or self.server_args.disable_flashinfer:
+        if (
+            self.server_args.disable_cuda_graph
+            or self.server_args.attention_backend != "flashinfer"
+        ):
             self.cuda_graph_runner = None
             return
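
CUDA graph capture remains a FlashInfer-only path: any other backend now disables the CudaGraphRunner, which is exactly what --disable-flashinfer used to imply. A minimal sketch of the gating predicate (the function name is hypothetical):

    def should_capture_cuda_graphs(disable_cuda_graph: bool, attention_backend: str) -> bool:
        # Graphs are captured only for the FlashInfer backend; anything else
        # falls back to eager execution (cuda_graph_runner = None).
        return not disable_cuda_graph and attention_backend == "flashinfer"

    assert should_capture_cuda_graphs(False, "flashinfer")
    assert not should_capture_cuda_graphs(False, "triton")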