diff --git a/python/sglang/srt/layers/attention/flashattention_backend.py b/python/sglang/srt/layers/attention/flashattention_backend.py
index 279a6dbd5..fd487d768 100644
--- a/python/sglang/srt/layers/attention/flashattention_backend.py
+++ b/python/sglang/srt/layers/attention/flashattention_backend.py
@@ -705,7 +705,9 @@ class FlashAttentionBackend(AttentionBackend):
         q = q.to(self.kv_cache_dtype)
         q_rope = q_rope.to(self.kv_cache_dtype) if q_rope is not None else None
         k_rope = k_rope.to(self.kv_cache_dtype) if k_rope is not None else None
-        causal = not layer.is_cross_attention
+        causal = True
+        if layer.is_cross_attention or layer.attn_type == AttentionType.ENCODER_ONLY:
+            causal = False
 
         # Check if we should use local attention
         use_local_attn = (
@@ -1005,7 +1007,9 @@ class FlashAttentionBackend(AttentionBackend):
             if layer.sliding_window_size is not None and layer.sliding_window_size > -1
             else (-1, -1)
         )
-        causal = not layer.is_cross_attention
+        causal = True
+        if layer.is_cross_attention or layer.attn_type == AttentionType.ENCODER_ONLY:
+            causal = False
 
         # For fa3 interface version compatibility, we put new fields into conditional keyword args
         kwargs = {}
diff --git a/python/sglang/srt/layers/attention/flashinfer_backend.py b/python/sglang/srt/layers/attention/flashinfer_backend.py
index 9f09a268a..473b61ca6 100644
--- a/python/sglang/srt/layers/attention/flashinfer_backend.py
+++ b/python/sglang/srt/layers/attention/flashinfer_backend.py
@@ -728,9 +728,10 @@ class FlashInferAttnBackend(AttentionBackend):
             )
         else:
             causal = True
-            if layer.attn_type == AttentionType.ENCODER_ONLY:
-                save_kv_cache = False
+            if layer.is_cross_attention or layer.attn_type == AttentionType.ENCODER_ONLY:
                 causal = False
+            if save_kv_cache and layer.attn_type == AttentionType.ENCODER_ONLY:
+                save_kv_cache = False
 
         if self.forward_metadata.extend_no_prefix:
             # NOTE: FlashInfer currently has limitations with head_dim = 32 or other dimensions
diff --git a/python/sglang/srt/layers/attention/triton_backend.py b/python/sglang/srt/layers/attention/triton_backend.py
index a483670db..71c034dd7 100644
--- a/python/sglang/srt/layers/attention/triton_backend.py
+++ b/python/sglang/srt/layers/attention/triton_backend.py
@@ -794,7 +794,7 @@ class TritonAttnBackend(AttentionBackend):
         logits_soft_cap = logit_capping_mod(layer.logit_capping_method, layer.logit_cap)
 
         causal = True
-        if layer.attn_type == AttentionType.ENCODER_ONLY:
+        if layer.is_cross_attention or layer.attn_type == AttentionType.ENCODER_ONLY:
             causal = False
 
         if layer.sliding_window_size is not None and layer.sliding_window_size > -1: