Simplify flashinfer dispatch (#1552)
This commit is contained in:
@@ -231,6 +231,7 @@ class ModelRunner:
|
||||
if hasattr(self.model, "get_attention_sliding_window_size")
|
||||
else None
|
||||
)
|
||||
self.has_cross_attention = getattr(self.model, "has_cross_attention", False)
|
||||
self.is_generation = is_generation_model(
|
||||
self.model_config.hf_config.architectures, self.server_args.is_embedding
|
||||
)
|
||||
@@ -453,6 +454,10 @@ class ModelRunner:
|
||||
"Window attention is not supported in the triton attention backend. "
|
||||
"Please use `--attention-backend flashinfer`."
|
||||
)
|
||||
assert not self.has_cross_attention, (
|
||||
"Cross attention is not supported in the triton attention backend. "
|
||||
"Please use `--attention-backend flashinfer`."
|
||||
)
|
||||
self.attn_backend = TritonAttnBackend(self)
|
||||
else:
|
||||
raise ValueError(
|
||||
|
||||
Reference in New Issue
Block a user