[bugfix]: use correct causality condition for flashattention, flashinfer, and triton backends (#10172)
This commit is contained in:
@@ -705,7 +705,9 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
q = q.to(self.kv_cache_dtype)
|
q = q.to(self.kv_cache_dtype)
|
||||||
q_rope = q_rope.to(self.kv_cache_dtype) if q_rope is not None else None
|
q_rope = q_rope.to(self.kv_cache_dtype) if q_rope is not None else None
|
||||||
k_rope = k_rope.to(self.kv_cache_dtype) if k_rope is not None else None
|
k_rope = k_rope.to(self.kv_cache_dtype) if k_rope is not None else None
|
||||||
causal = not layer.is_cross_attention
|
causal = True
|
||||||
|
if layer.is_cross_attention or layer.attn_type == AttentionType.ENCODER_ONLY:
|
||||||
|
causal = False
|
||||||
|
|
||||||
# Check if we should use local attention
|
# Check if we should use local attention
|
||||||
use_local_attn = (
|
use_local_attn = (
|
||||||
@@ -1005,7 +1007,9 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
if layer.sliding_window_size is not None and layer.sliding_window_size > -1
|
if layer.sliding_window_size is not None and layer.sliding_window_size > -1
|
||||||
else (-1, -1)
|
else (-1, -1)
|
||||||
)
|
)
|
||||||
causal = not layer.is_cross_attention
|
causal = True
|
||||||
|
if layer.is_cross_attention or layer.attn_type == AttentionType.ENCODER_ONLY:
|
||||||
|
causal = False
|
||||||
|
|
||||||
# For fa3 interface version compatibility, we put new fields into conditional keyword args
|
# For fa3 interface version compatibility, we put new fields into conditional keyword args
|
||||||
kwargs = {}
|
kwargs = {}
|
||||||
|
|||||||
@@ -728,9 +728,10 @@ class FlashInferAttnBackend(AttentionBackend):
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
causal = True
|
causal = True
|
||||||
if layer.attn_type == AttentionType.ENCODER_ONLY:
|
if layer.is_cross_attention or layer.attn_type == AttentionType.ENCODER_ONLY:
|
||||||
save_kv_cache = False
|
|
||||||
causal = False
|
causal = False
|
||||||
|
if save_kv_cache and layer.attn_type == AttentionType.ENCODER_ONLY:
|
||||||
|
save_kv_cache = False
|
||||||
|
|
||||||
if self.forward_metadata.extend_no_prefix:
|
if self.forward_metadata.extend_no_prefix:
|
||||||
# NOTE: FlashInfer currently has limitations with head_dim = 32 or other dimensions
|
# NOTE: FlashInfer currently has limitations with head_dim = 32 or other dimensions
|
||||||
|
|||||||
@@ -794,7 +794,7 @@ class TritonAttnBackend(AttentionBackend):
|
|||||||
logits_soft_cap = logit_capping_mod(layer.logit_capping_method, layer.logit_cap)
|
logits_soft_cap = logit_capping_mod(layer.logit_capping_method, layer.logit_cap)
|
||||||
|
|
||||||
causal = True
|
causal = True
|
||||||
if layer.attn_type == AttentionType.ENCODER_ONLY:
|
if layer.is_cross_attention or layer.attn_type == AttentionType.ENCODER_ONLY:
|
||||||
causal = False
|
causal = False
|
||||||
|
|
||||||
if layer.sliding_window_size is not None and layer.sliding_window_size > -1:
|
if layer.sliding_window_size is not None and layer.sliding_window_size > -1:
|
||||||
|
|||||||
Reference in New Issue
Block a user