Refactor attention backend (#1381)

This commit is contained in:
Lianmin Zheng
2024-09-11 11:44:26 -07:00
committed by GitHub
parent c03cece42f
commit fec185ce0c
16 changed files with 568 additions and 564 deletions

View File

@@ -83,8 +83,8 @@ class ServerArgs:
json_model_override_args: str = "{}"
# Optimization/debug options
attention_backend: str = "flashinfer"
sampling_backend: str = "flashinfer"
attention_backend: Optional[str] = None
sampling_backend: Optional[str] = None
disable_flashinfer: bool = False
disable_flashinfer_sampling: bool = False
@@ -148,6 +148,17 @@ class ServerArgs:
)
self.sampling_backend = "pytorch"
# Default kernel backends
if self.enable_mla:
logger.info("MLA optimization is turned on. Use triton backend.")
self.attention_backend = "triton"
if self.attention_backend is None:
self.attention_backend = "flashinfer"
if self.sampling_backend is None:
self.sampling_backend = "flashinfer"
# Model-specific patches
if "Alibaba-NLP/gte-Qwen2-1.5B-instruct" == self.model_path:
logger.info(