Refactor attention backend (#1381)
@@ -83,8 +83,8 @@ class ServerArgs:
     json_model_override_args: str = "{}"
 
     # Optimization/debug options
-    attention_backend: str = "flashinfer"
-    sampling_backend: str = "flashinfer"
+    attention_backend: Optional[str] = None
+    sampling_backend: Optional[str] = None
 
     disable_flashinfer: bool = False
     disable_flashinfer_sampling: bool = False
@@ -148,6 +148,17 @@ class ServerArgs:
         )
         self.sampling_backend = "pytorch"
 
+        # Default kernel backends
+        if self.enable_mla:
+            logger.info("MLA optimization is turned on. Use triton backend.")
+            self.attention_backend = "triton"
+
+        if self.attention_backend is None:
+            self.attention_backend = "flashinfer"
+
+        if self.sampling_backend is None:
+            self.sampling_backend = "flashinfer"
+
         # Model-specific patches
         if "Alibaba-NLP/gte-Qwen2-1.5B-instruct" == self.model_path:
             logger.info(
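
Below is a minimal, runnable sketch of the backend-selection logic this diff introduces: both backend fields now default to None and are resolved in post-init, with MLA forcing the triton attention backend and flashinfer used as the fallback. The simplified ServerArgs keeps only the fields relevant to the change; the plain enable_mla boolean and the module-level logger are illustrative assumptions, not the full class from the repository.

import logging
from dataclasses import dataclass
from typing import Optional

logger = logging.getLogger(__name__)


@dataclass
class ServerArgs:
    # Only the fields touched by this commit; the real class has many more.
    json_model_override_args: str = "{}"
    attention_backend: Optional[str] = None
    sampling_backend: Optional[str] = None
    enable_mla: bool = False  # simplified stand-in for the real MLA flag


    def __post_init__(self):
        # Default kernel backends: MLA forces triton, otherwise unset
        # backends fall back to flashinfer.
        if self.enable_mla:
            logger.info("MLA optimization is turned on. Use triton backend.")
            self.attention_backend = "triton"

        if self.attention_backend is None:
            self.attention_backend = "flashinfer"

        if self.sampling_backend is None:
            self.sampling_backend = "flashinfer"


# Explicitly chosen backends are respected; unset ones get the defaults.
args = ServerArgs(enable_mla=True)
assert args.attention_backend == "triton"
assert args.sampling_backend == "flashinfer"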