From 08c4d764a51e795515d49a2d8aaabdee1ba66ab7 Mon Sep 17 00:00:00 2001
From: Lianmin Zheng
Date: Sat, 8 Mar 2025 00:41:35 -0800
Subject: [PATCH] lazy import attn backends (#4200)

Import the attention backends lazily inside
ModelRunner.init_attention_backend() instead of at module import time, so
that only the backend selected by --attention-backend (and its optional
dependencies, e.g. flashinfer) is imported at startup. To break the
remaining cross-backend dependency, the triton backend now imports
create_flashinfer_kv_indices_triton from sglang.srt.layers.attention.utils
rather than from the flashinfer backend. Also materialize
reversed(self.capture_bs) as a list so tqdm can report progress against a
known total, drop the unused get_exception_traceback import, and raise the
sampling temperature from 0 to 0.1 in the EAGLE EOS-token test.
---
 .../srt/layers/attention/triton_backend.py    |  4 +---
 .../srt/model_executor/cuda_graph_runner.py   |  2 +-
 .../sglang/srt/model_executor/model_runner.py | 24 ++++++++++++++++++++------
 test/srt/test_eagle_infer.py                  |  2 +-
 4 files changed, 21 insertions(+), 11 deletions(-)

diff --git a/python/sglang/srt/layers/attention/triton_backend.py b/python/sglang/srt/layers/attention/triton_backend.py
index a9d726180..b942dee5c 100644
--- a/python/sglang/srt/layers/attention/triton_backend.py
+++ b/python/sglang/srt/layers/attention/triton_backend.py
@@ -6,9 +6,7 @@
 import torch
 import triton
 
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
-from sglang.srt.layers.attention.flashinfer_backend import (
-    create_flashinfer_kv_indices_triton,
-)
+from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
 from sglang.srt.layers.dp_attention import get_attention_tp_size
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py
index 6a2bab22a..83c2d88f0 100644
--- a/python/sglang/srt/model_executor/cuda_graph_runner.py
+++ b/python/sglang/srt/model_executor/cuda_graph_runner.py
@@ -302,7 +302,7 @@ class CudaGraphRunner:
         self.stream = graph_capture_context.stream
         # Reverse the order to enable better memory sharing across cuda graphs.
         capture_range = (
-            tqdm.tqdm(reversed(self.capture_bs))
+            tqdm.tqdm(list(reversed(self.capture_bs)))
             if get_tensor_model_parallel_rank() == 0
             else reversed(self.capture_bs)
         )
diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py
index 6489ea6ed..8040709a7 100644
--- a/python/sglang/srt/model_executor/model_runner.py
+++ b/python/sglang/srt/model_executor/model_runner.py
@@ -35,11 +35,6 @@ from sglang.srt.distributed import (
     set_custom_all_reduce,
 )
 from sglang.srt.distributed.parallel_state import monkey_patch_vllm_parallel_state
-from sglang.srt.layers.attention.double_sparsity_backend import DoubleSparseAttnBackend
-from sglang.srt.layers.attention.flashinfer_backend import FlashInferAttnBackend
-from sglang.srt.layers.attention.flashinfer_mla_backend import FlashInferMLAAttnBackend
-from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBackend
-from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
 from sglang.srt.layers.dp_attention import (
     get_attention_tp_group,
     get_attention_tp_size,
@@ -77,7 +72,6 @@ from sglang.srt.utils import (
     set_cpu_offload_max_bytes,
     set_cuda_arch,
 )
-from sglang.utils import get_exception_traceback
 
 logger = logging.getLogger(__name__)
 
@@ -779,6 +773,10 @@ class ModelRunner:
     def init_attention_backend(self):
         """Init attention kernel backend."""
         if self.server_args.attention_backend == "flashinfer":
+            from sglang.srt.layers.attention.flashinfer_backend import (
+                FlashInferAttnBackend,
+            )
+
             # Init streams
             if self.server_args.speculative_algorithm == "EAGLE":
                 self.plan_stream_for_flashinfer = torch.cuda.Stream()
@@ -794,12 +792,26 @@
                 "Please use `--attention-backend flashinfer`."
             )
             if self.server_args.enable_double_sparsity:
+                from sglang.srt.layers.attention.double_sparsity_backend import (
+                    DoubleSparseAttnBackend,
+                )
+
                 self.attn_backend = DoubleSparseAttnBackend(self)
             else:
+                from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
+
                 self.attn_backend = TritonAttnBackend(self)
         elif self.server_args.attention_backend == "torch_native":
+            from sglang.srt.layers.attention.torch_native_backend import (
+                TorchNativeAttnBackend,
+            )
+
             self.attn_backend = TorchNativeAttnBackend(self)
         elif self.server_args.attention_backend == "flashinfer_mla":
+            from sglang.srt.layers.attention.flashinfer_mla_backend import (
+                FlashInferMLAAttnBackend,
+            )
+
             self.attn_backend = FlashInferMLAAttnBackend(self)
         else:
             raise ValueError(
diff --git a/test/srt/test_eagle_infer.py b/test/srt/test_eagle_infer.py
index 5b89071b6..a87b6e37b 100644
--- a/test/srt/test_eagle_infer.py
+++ b/test/srt/test_eagle_infer.py
@@ -108,7 +108,7 @@ class TestEAGLEEngine(unittest.TestCase):
     def _test_eos_token(self, engine):
         prompt = "[INST] <<SYS>>\nYou are a helpful assistant.\n<</SYS>>\nToday is a sunny day and I like [/INST]"
         params = {
-            "temperature": 0,
+            "temperature": 0.1,
             "max_new_tokens": 1024,
             "skip_special_tokens": False,
         }
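
The patch applies a standard deferred-import pattern: resolve a backend module only in the branch that selects it, so an unused backend (and any optional dependency it pulls in, such as flashinfer) is never imported at process startup. A minimal sketch of the idea follows; the sglang module paths are the ones touched by this patch, but the registry dict and the load_attn_backend() helper are illustrative only and not part of sglang:

import importlib

# Hypothetical registry mapping an --attention-backend value to "module:class".
# The module paths mirror this patch; the registry itself is illustrative.
_ATTN_BACKENDS = {
    "flashinfer": "sglang.srt.layers.attention.flashinfer_backend:FlashInferAttnBackend",
    "flashinfer_mla": "sglang.srt.layers.attention.flashinfer_mla_backend:FlashInferMLAAttnBackend",
    "triton": "sglang.srt.layers.attention.triton_backend:TritonAttnBackend",
    "torch_native": "sglang.srt.layers.attention.torch_native_backend:TorchNativeAttnBackend",
}

def load_attn_backend(name: str):
    """Resolve a backend class lazily; nothing is imported until this runs."""
    try:
        target = _ATTN_BACKENDS[name]
    except KeyError:
        raise ValueError(f"Invalid attention backend: {name}")
    module_name, _, class_name = target.partition(":")
    module = importlib.import_module(module_name)  # the deferred import happens here
    return getattr(module, class_name)

Because the import is deferred, a missing optional package (e.g. flashinfer not installed) raises ImportError only when that backend is actually requested, rather than on every server start.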