diff --git a/python/sglang/srt/layers/attention/flashattention_backend.py b/python/sglang/srt/layers/attention/flashattention_backend.py
index fd487d768..927f1d93c 100644
--- a/python/sglang/srt/layers/attention/flashattention_backend.py
+++ b/python/sglang/srt/layers/attention/flashattention_backend.py
@@ -1,7 +1,7 @@
 from __future__ import annotations
 
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Optional, Union
+from typing import TYPE_CHECKING, Optional
 
 import numpy as np
 import torch
@@ -10,6 +10,7 @@ import triton.language as tl
 
 from sglang.srt.configs.model_config import AttentionArch
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
+from sglang.srt.layers.radix_attention import AttentionType
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.srt.speculative.spec_info import SpecInput
diff --git a/python/sglang/srt/layers/attention/flashinfer_backend.py b/python/sglang/srt/layers/attention/flashinfer_backend.py
index 473b61ca6..3d551efee 100644
--- a/python/sglang/srt/layers/attention/flashinfer_backend.py
+++ b/python/sglang/srt/layers/attention/flashinfer_backend.py
@@ -728,7 +728,10 @@ class FlashInferAttnBackend(AttentionBackend):
             )
         else:
             causal = True
-            if layer.is_cross_attention or layer.attn_type == AttentionType.ENCODER_ONLY:
+            if (
+                layer.is_cross_attention
+                or layer.attn_type == AttentionType.ENCODER_ONLY
+            ):
                 causal = False
             if save_kv_cache and layer.attn_type == AttentionType.ENCODER_ONLY:
                 save_kv_cache = False