Simplify flashinfer dispatch (#1552)

This commit is contained in:
Liangsheng Yin
2024-10-01 00:28:42 -07:00
committed by GitHub
parent 619bb6ddda
commit 100f5b8bc9
5 changed files with 97 additions and 76 deletions

View File

@@ -231,6 +231,7 @@ class ModelRunner:
if hasattr(self.model, "get_attention_sliding_window_size")
else None
)
self.has_cross_attention = getattr(self.model, "has_cross_attention", False)
self.is_generation = is_generation_model(
self.model_config.hf_config.architectures, self.server_args.is_embedding
)
@@ -453,6 +454,10 @@ class ModelRunner:
"Window attention is not supported in the triton attention backend. "
"Please use `--attention-backend flashinfer`."
)
assert not self.has_cross_attention, (
"Cross attention is not supported in the triton attention backend. "
"Please use `--attention-backend flashinfer`."
)
self.attn_backend = TritonAttnBackend(self)
else:
raise ValueError(