[Fix] Fix dual chunk model default behavior (#9032)
This commit is contained in:
@@ -483,7 +483,7 @@ class DualChunkFlashAttentionBackend(AttentionBackend):
|
|||||||
).squeeze(1)
|
).squeeze(1)
|
||||||
return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
|
return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
|
||||||
|
|
||||||
def init_cuda_graph_state(self, max_bs: int):
|
def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
|
||||||
"""Initialize CUDA graph state for the attention backend.
|
"""Initialize CUDA graph state for the attention backend.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|||||||
@@ -388,6 +388,19 @@ class ModelRunner:
|
|||||||
): # override the default attention backend
|
): # override the default attention backend
|
||||||
server_args.attention_backend = server_args.prefill_attention_backend
|
server_args.attention_backend = server_args.prefill_attention_backend
|
||||||
|
|
||||||
|
if (
|
||||||
|
getattr(self.model_config.hf_config, "dual_chunk_attention_config", None)
|
||||||
|
is not None
|
||||||
|
):
|
||||||
|
if server_args.attention_backend is None:
|
||||||
|
server_args.attention_backend = "dual_chunk_flash_attn"
|
||||||
|
logger.info("Dual chunk attention is turned on by default.")
|
||||||
|
elif server_args.attention_backend != "dual_chunk_flash_attn":
|
||||||
|
raise ValueError(
|
||||||
|
"Dual chunk attention is enabled, but attention backend is set to "
|
||||||
|
f"{server_args.attention_backend}. Please set it to 'dual_chunk_flash_attn'."
|
||||||
|
)
|
||||||
|
|
||||||
if server_args.attention_backend is None:
|
if server_args.attention_backend is None:
|
||||||
"""
|
"""
|
||||||
Auto select the fastest attention backend.
|
Auto select the fastest attention backend.
|
||||||
|
|||||||
Reference in New Issue
Block a user