lazy import attn backends (#4200)

This commit is contained in:
Lianmin Zheng
2025-03-08 00:41:35 -08:00
committed by GitHub
parent 96d0e37fa7
commit 08c4d764a5
4 changed files with 21 additions and 11 deletions

View File

@@ -6,9 +6,7 @@ import torch
import triton
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
from sglang.srt.layers.attention.flashinfer_backend import (
create_flashinfer_kv_indices_triton,
)
from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
from sglang.srt.layers.dp_attention import get_attention_tp_size
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode

View File

@@ -302,7 +302,7 @@ class CudaGraphRunner:
self.stream = graph_capture_context.stream
# Reverse the order to enable better memory sharing across cuda graphs.
capture_range = (
tqdm.tqdm(reversed(self.capture_bs))
tqdm.tqdm(list(reversed(self.capture_bs)))
if get_tensor_model_parallel_rank() == 0
else reversed(self.capture_bs)
)

View File

@@ -35,11 +35,6 @@ from sglang.srt.distributed import (
set_custom_all_reduce,
)
from sglang.srt.distributed.parallel_state import monkey_patch_vllm_parallel_state
from sglang.srt.layers.attention.double_sparsity_backend import DoubleSparseAttnBackend
from sglang.srt.layers.attention.flashinfer_backend import FlashInferAttnBackend
from sglang.srt.layers.attention.flashinfer_mla_backend import FlashInferMLAAttnBackend
from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBackend
from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
from sglang.srt.layers.dp_attention import (
get_attention_tp_group,
get_attention_tp_size,
@@ -77,7 +72,6 @@ from sglang.srt.utils import (
set_cpu_offload_max_bytes,
set_cuda_arch,
)
from sglang.utils import get_exception_traceback
logger = logging.getLogger(__name__)
@@ -779,6 +773,10 @@ class ModelRunner:
def init_attention_backend(self):
"""Init attention kernel backend."""
if self.server_args.attention_backend == "flashinfer":
from sglang.srt.layers.attention.flashinfer_backend import (
FlashInferAttnBackend,
)
# Init streams
if self.server_args.speculative_algorithm == "EAGLE":
self.plan_stream_for_flashinfer = torch.cuda.Stream()
@@ -794,12 +792,26 @@ class ModelRunner:
"Please use `--attention-backend flashinfer`."
)
if self.server_args.enable_double_sparsity:
from sglang.srt.layers.attention.double_sparsity_backend import (
DoubleSparseAttnBackend,
)
self.attn_backend = DoubleSparseAttnBackend(self)
else:
from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
self.attn_backend = TritonAttnBackend(self)
elif self.server_args.attention_backend == "torch_native":
from sglang.srt.layers.attention.torch_native_backend import (
TorchNativeAttnBackend,
)
self.attn_backend = TorchNativeAttnBackend(self)
elif self.server_args.attention_backend == "flashinfer_mla":
from sglang.srt.layers.attention.flashinfer_mla_backend import (
FlashInferMLAAttnBackend,
)
self.attn_backend = FlashInferMLAAttnBackend(self)
else:
raise ValueError(