From b1e5a33ae337d20e35e966b8d82a02a913d32689 Mon Sep 17 00:00:00 2001 From: Lifu Huang Date: Mon, 9 Jun 2025 00:22:45 -0700 Subject: [PATCH] Eliminate stream sync to speed up LoRA batch init (#6960) --- python/sglang/srt/lora/layers.py | 27 ++++++++++++--------------- 1 file changed, 12 insertions(+), 15 deletions(-) diff --git a/python/sglang/srt/lora/layers.py b/python/sglang/srt/lora/layers.py index aa10ef6b7..50d8c3888 100644 --- a/python/sglang/srt/lora/layers.py +++ b/python/sglang/srt/lora/layers.py @@ -137,7 +137,7 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA): self.A_buffer_gate_up = A_buffer if self.lora_backend.fuse_stacked_lora_b: # B_buffer_gate_up: (num_lora, 2 * output_dim, r) - if not hasattr(self, "B_buffer_gate_up") or self.B_buffer_gate_up is None: + if getattr(self, "B_buffer_gate_up", None) is None: self.B_buffer_gate_up = torch.empty( ( B_buffer[0].shape[0], @@ -202,7 +202,7 @@ class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA): output_dim_q, output_dim_kv = B_buffer_q.shape[-2], B_buffer_kv.shape[-2] # B_buffer_qkv: (num_lora, output_dim_q + 2 * output_dim_kv, r) - if not hasattr(self, "B_buffer_qkv") or self.B_buffer_qkv is None: + if getattr(self, "B_buffer_qkv", None) is None: self.B_buffer_qkv = torch.empty( ( B_buffer_q[0].shape[0], @@ -221,20 +221,17 @@ class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA): ) # Offsets of q/k/v in output dimension - if not hasattr(self, "output_offset") or self.output_offset is None: - self.output_offset = torch.empty( - 4, dtype=torch.int32, device=B_buffer_q.device + if getattr(self, "output_offset", None) is None: + self.output_offset = torch.tensor( + [ + 0, + output_dim_q, + output_dim_q + output_dim_kv, + output_dim_q + 2 * output_dim_kv, + ], + dtype=torch.int32, + device=B_buffer_q.device, ) - self.output_offset[:4] = torch.tensor( - [ - 0, - output_dim_q, - output_dim_q + output_dim_kv, - output_dim_q + 2 * output_dim_kv, - ], - dtype=torch.int32, - device=B_buffer_q.device, - ) # For computing number of launched blocks self.max_qkv_out_dim = max(output_dim_q, output_dim_kv) else: