Deprecate --disable-flashinfer and introduce --attention-backend (#1380)

This commit is contained in:
Lianmin Zheng
2024-09-10 17:11:16 -07:00
committed by GitHub
parent 3a6e8b6d78
commit 46094e0c1b
13 changed files with 99 additions and 61 deletions

View File

@@ -61,14 +61,18 @@ class RadixAttention(nn.Module):
# Choose backend
if (
- not global_server_args_dict.get("disable_flashinfer", False)
+ global_server_args_dict["attention_backend"] == "flashinfer"
and self.qk_head_dim == self.v_head_dim
):
self.extend_forward = self.extend_forward_flashinfer
self.decode_forward = self.decode_forward_flashinfer
- else:
+ elif global_server_args_dict["attention_backend"] == "triton":
self.extend_forward = self.extend_forward_triton
self.decode_forward = self.decode_forward_triton
else:
raise ValueError(
f"Invalid attention backend: {global_server_args_dict['attention_backend']}"
)
def extend_forward_triton(self, q, k, v, input_metadata: InputMetadata):
if self.qk_head_dim != self.v_head_dim: