lazy import attn backends (#4200)
@@ -6,9 +6,7 @@ import torch
 import triton
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
-from sglang.srt.layers.attention.flashinfer_backend import (
-    create_flashinfer_kv_indices_triton,
-)
+from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
 from sglang.srt.layers.dp_attention import get_attention_tp_size
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode

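For context: create_flashinfer_kv_indices_triton is a Triton kernel, and the hunk above moves its import to the backend-neutral attention.utils module. A minimal before/after sketch, assuming only the import paths shown in this diff:

# Old location (removed above): the helper lived in (or was re-exported through)
# flashinfer_backend, so importing it dragged that module and its dependencies
# into the import graph.
# from sglang.srt.layers.attention.flashinfer_backend import create_flashinfer_kv_indices_triton

# New backend-neutral location: only the triton package is needed.
from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
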
@@ -302,7 +302,7 @@ class CudaGraphRunner:
             self.stream = graph_capture_context.stream
             # Reverse the order to enable better memory sharing across cuda graphs.
             capture_range = (
-                tqdm.tqdm(reversed(self.capture_bs))
+                tqdm.tqdm(list(reversed(self.capture_bs)))
                 if get_tensor_model_parallel_rank() == 0
                 else reversed(self.capture_bs)
             )

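The list() wrapper matters because reversed() returns an iterator without __len__, so tqdm cannot infer a total and shows only a bare item counter. A small standalone sketch of the difference (plain tqdm, not the SGLang code):

import tqdm

capture_bs = [1, 2, 4, 8, 16]

# The reversed iterator has no __len__: tqdm shows a running count, no percentage or ETA.
for bs in tqdm.tqdm(reversed(capture_bs)):
    pass

# Wrapping in list() restores len(), so tqdm renders a full progress bar.
for bs in tqdm.tqdm(list(reversed(capture_bs))):
    pass
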
@@ -35,11 +35,6 @@ from sglang.srt.distributed import (
     set_custom_all_reduce,
 )
 from sglang.srt.distributed.parallel_state import monkey_patch_vllm_parallel_state
-from sglang.srt.layers.attention.double_sparsity_backend import DoubleSparseAttnBackend
-from sglang.srt.layers.attention.flashinfer_backend import FlashInferAttnBackend
-from sglang.srt.layers.attention.flashinfer_mla_backend import FlashInferMLAAttnBackend
-from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBackend
-from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
 from sglang.srt.layers.dp_attention import (
     get_attention_tp_group,
     get_attention_tp_size,

@@ -77,7 +72,6 @@ from sglang.srt.utils import (
     set_cpu_offload_max_bytes,
     set_cuda_arch,
 )
-from sglang.utils import get_exception_traceback

 logger = logging.getLogger(__name__)

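With the module-level backend imports removed, loading the ModelRunner module should no longer import any attention backend eagerly. A rough way to check this (an assumed verification workflow, not part of the diff; the exact result depends on the rest of the import graph):

import importlib
import sys

importlib.import_module("sglang.srt.model_executor.model_runner")

# After this change, backend modules should not show up here until
# init_attention_backend() actually selects one.
print([m for m in sys.modules if "layers.attention" in m and m.endswith("_backend")])
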
@@ -779,6 +773,10 @@ class ModelRunner:
     def init_attention_backend(self):
         """Init attention kernel backend."""
         if self.server_args.attention_backend == "flashinfer":
+            from sglang.srt.layers.attention.flashinfer_backend import (
+                FlashInferAttnBackend,
+            )
+
             # Init streams
             if self.server_args.speculative_algorithm == "EAGLE":
                 self.plan_stream_for_flashinfer = torch.cuda.Stream()

@@ -794,12 +792,26 @@ class ModelRunner:
                 "Please use `--attention-backend flashinfer`."
             )
             if self.server_args.enable_double_sparsity:
+                from sglang.srt.layers.attention.double_sparsity_backend import (
+                    DoubleSparseAttnBackend,
+                )
+
                 self.attn_backend = DoubleSparseAttnBackend(self)
             else:
+                from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
+
                 self.attn_backend = TritonAttnBackend(self)
         elif self.server_args.attention_backend == "torch_native":
+            from sglang.srt.layers.attention.torch_native_backend import (
+                TorchNativeAttnBackend,
+            )
+
             self.attn_backend = TorchNativeAttnBackend(self)
         elif self.server_args.attention_backend == "flashinfer_mla":
+            from sglang.srt.layers.attention.flashinfer_mla_backend import (
+                FlashInferMLAAttnBackend,
+            )
+
             self.attn_backend = FlashInferMLAAttnBackend(self)
         else:
             raise ValueError(

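Taken together, the two ModelRunner hunks move every backend import into the branch that selects it. A condensed sketch of the resulting pattern, with a simplified, hypothetical helper signature rather than the actual ModelRunner code (the import paths are the ones used in the diff):

def create_attn_backend(runner, attention_backend: str):
    # Each backend module is imported only when that backend is chosen, so a
    # deployment that never selects flashinfer does not import it at startup.
    if attention_backend == "flashinfer":
        from sglang.srt.layers.attention.flashinfer_backend import FlashInferAttnBackend

        return FlashInferAttnBackend(runner)
    elif attention_backend == "triton":
        from sglang.srt.layers.attention.triton_backend import TritonAttnBackend

        return TritonAttnBackend(runner)
    elif attention_backend == "torch_native":
        from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBackend

        return TorchNativeAttnBackend(runner)
    elif attention_backend == "flashinfer_mla":
        from sglang.srt.layers.attention.flashinfer_mla_backend import FlashInferMLAAttnBackend

        return FlashInferMLAAttnBackend(runner)
    else:
        raise ValueError(f"Invalid attention backend: {attention_backend}")
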