diff --git a/docs/backend/lora.ipynb b/docs/backend/lora.ipynb index 733f75178..6658517ae 100644 --- a/docs/backend/lora.ipynb +++ b/docs/backend/lora.ipynb @@ -35,7 +35,7 @@ "\n", "* `max_loaded_loras`: If specified, it limits the maximum number of LoRA adapters loaded in CPU memory at a time. The value must be greater than or equal to `max-loras-per-batch`.\n", "\n", - "* `lora_backend`: The backend of running GEMM kernels for Lora modules. It can be one of `triton` or `flashinfer`, and set to `triton` by default. For better performance and stability, we recommend using the Triton LoRA backend. In the future, faster backend built upon Cutlass or Cuda kernels will be added.\n", + "* `lora_backend`: The backend of running GEMM kernels for Lora modules. Currently we only support Triton LoRA backend. In the future, faster backend built upon Cutlass or Cuda kernels will be added.\n", "\n", "* `max_lora_rank`: The maximum LoRA rank that should be supported. If not specified, it will be automatically inferred from the adapters provided in `--lora-paths`. This argument is needed when you expect to dynamically load adapters of larger LoRA rank after server startup.\n", "\n", diff --git a/python/sglang/srt/lora/backend/base_backend.py b/python/sglang/srt/lora/backend/base_backend.py index e1bdc5408..fe8bd3d20 100644 --- a/python/sglang/srt/lora/backend/base_backend.py +++ b/python/sglang/srt/lora/backend/base_backend.py @@ -5,22 +5,6 @@ import torch from sglang.srt.lora.utils import LoRABatchInfo -def get_fuse_output_add_from_name(name: str) -> bool: - mapping = { - "triton": True, - "flashinfer": False, - } - return mapping.get(name, False) - - -def get_fuse_stacked_lora_b_from_name(name: str) -> bool: - mapping = { - "triton": True, - "flashinfer": False, - } - return mapping.get(name, False) - - class BaseLoRABackend: """Base class for different Lora backends. Each backend has its own implementation of Lora kernels. @@ -28,15 +12,11 @@ class BaseLoRABackend: Args: name: name of backend batch_info: information of current batch for use - fuse_output_add: if set to True, the output buffer for storing result will be passed in when doing lora_b forward, - and the operation of adding will be fused into kernel """ def __init__(self, name: str, batch_info: LoRABatchInfo = None): self.name = name self.batch_info = batch_info - self.fuse_output_add = get_fuse_output_add_from_name(name) - self.fuse_stacked_lora_b = get_fuse_stacked_lora_b_from_name(name) def run_lora_a_sgemm( self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs @@ -126,8 +106,8 @@ def get_backend_from_name(name: str) -> BaseLoRABackend: return TritonLoRABackend elif name == "flashinfer": - from sglang.srt.lora.backend.flashinfer_backend import FlashInferLoRABackend - - return FlashInferLoRABackend + raise ValueError( + "FlashInfer LoRA backend has been deprecated, please use `triton` instead." + ) else: raise ValueError(f"Invalid backend: {name}") diff --git a/python/sglang/srt/lora/backend/flashinfer_backend.py b/python/sglang/srt/lora/backend/flashinfer_backend.py deleted file mode 100644 index 0370c6c81..000000000 --- a/python/sglang/srt/lora/backend/flashinfer_backend.py +++ /dev/null @@ -1,131 +0,0 @@ -from typing import Tuple - -import torch - -from sglang.srt.lora.backend.base_backend import BaseLoRABackend -from sglang.srt.lora.utils import LoRABatchInfo -from sglang.srt.utils import is_flashinfer_available - -if is_flashinfer_available(): - from flashinfer import SegmentGEMMWrapper - - -class FlashInferLoRABackend(BaseLoRABackend): - - def __init__(self, name: str, batch_info: LoRABatchInfo = None): - super().__init__(name, batch_info) - - # Set up SGemm Wrapper from flashinfer - # FIXME wait for flashinfer segment gemm update - workspace_buffer = torch.empty(1 * 1024 * 1024, dtype=torch.int8, device="cuda") - self.segment_gemm = SegmentGEMMWrapper(workspace_buffer) - - def run_lora_a_sgemm( - self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs - ) -> torch.Tensor: - - return self.segment_gemm.run( - x=x, - weights=weights, - batch_size=self.batch_info.bs, - weight_column_major=True, - seg_indptr=self.batch_info.seg_indptr, - weight_indices=self.batch_info.weight_indices, - ) - - def run_lora_b_sgemm( - self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs - ) -> torch.Tensor: - - return ( - self.segment_gemm.run( - x=x, - weights=weights, - batch_size=self.batch_info.bs, - weight_column_major=True, - seg_indptr=self.batch_info.seg_indptr, - weight_indices=self.batch_info.weight_indices, - ) - * self.batch_info.scalings[0] - ) - - def run_qkv_lora( - self, - x: torch.Tensor, - qkv_lora_a: torch.Tensor, - qkv_lora_b: Tuple[torch.Tensor], - *args, - **kwargs, - ) -> torch.Tensor: - - assert isinstance(qkv_lora_b, tuple) and len(qkv_lora_b) == 2 - - # Shape of lora_a_output: (s, 3 * r) - lora_a_output = self.run_lora_a_sgemm(x=x, weights=qkv_lora_a) - - q_lora_b, kv_lora_b = qkv_lora_b - lora_rank = kv_lora_b.shape[-1] - output_dim_q = q_lora_b.shape[-2] - output_dim_kv = kv_lora_b.shape[-2] - lora_output = torch.empty( - (x.shape[0], output_dim_q + 2 * output_dim_kv), - device=x.device, - dtype=x.dtype, - ) - - # q - lora_output[:, :output_dim_q] = self.run_lora_b_sgemm( - x=lora_a_output[:, :lora_rank].contiguous(), weights=q_lora_b[0] - ) - - # kv - lora_output[:, output_dim_q : output_dim_q + output_dim_kv] = ( - self.run_lora_b_sgemm( - x=lora_a_output[:, lora_rank : 2 * lora_rank].contiguous(), - weights=kv_lora_b[0], - ) - ) - - lora_output[ - :, output_dim_q + output_dim_kv : output_dim_q + 2 * output_dim_kv - ] = self.run_lora_b_sgemm( - x=lora_a_output[:, 2 * lora_rank : 3 * lora_rank].contiguous(), - weights=kv_lora_b[1], - ) - - return lora_output * self.batch_info.scalings[0] - - def run_gate_up_lora( - self, - x: torch.Tensor, - gate_up_lora_a: torch.Tensor, - gate_up_lora_b: Tuple[torch.Tensor], - *args, - **kwargs, - ) -> torch.Tensor: - - assert isinstance(gate_up_lora_b, tuple) and len(gate_up_lora_b) == 2 - lora_rank = gate_up_lora_b[0].shape[-1] - output_dim = gate_up_lora_b[0].shape[-2] - - # Shape of lora_a_output: (s, 2 * r) - lora_a_output = self.run_lora_a_sgemm(x=x, weights=gate_up_lora_a) - - lora_output = torch.empty( - (x.shape[0], 2 * output_dim), - device=x.device, - dtype=x.dtype, - ) - - # Compute lora for gate and up proj respectively - lora_output[:, :output_dim] = self.run_lora_b_sgemm( - x=lora_a_output[:, :lora_rank].contiguous(), - weights=gate_up_lora_b[0], - ) - - lora_output[:, output_dim:] = self.run_lora_b_sgemm( - x=lora_a_output[:, lora_rank:].contiguous(), - weights=gate_up_lora_b[1], - ) - - return lora_output * self.batch_info.scalings[0] diff --git a/python/sglang/srt/lora/layers.py b/python/sglang/srt/lora/layers.py index 50d8c3888..4328a7601 100644 --- a/python/sglang/srt/lora/layers.py +++ b/python/sglang/srt/lora/layers.py @@ -1,5 +1,3 @@ -from typing import List, Tuple - import torch from torch import nn @@ -79,18 +77,13 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA): self.B_buffer = B_buffer def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor: - backend_kwargs = {"base_output": base_output} lora_a_output = self.lora_backend.run_lora_a_sgemm(x, self.A_buffer) lora_output = self.lora_backend.run_lora_b_sgemm( - lora_a_output, - self.B_buffer[0], - **backend_kwargs, - ) - return ( - lora_output - if self.lora_backend.fuse_output_add - else base_output + lora_output + x=lora_a_output, + weights=self.B_buffer, + base_output=base_output, ) + return lora_output def forward(self, input_: torch.Tensor): # duplicate the logic in ColumnParallelLinear @@ -135,37 +128,16 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA): ): self.set_lora = True self.A_buffer_gate_up = A_buffer - if self.lora_backend.fuse_stacked_lora_b: - # B_buffer_gate_up: (num_lora, 2 * output_dim, r) - if getattr(self, "B_buffer_gate_up", None) is None: - self.B_buffer_gate_up = torch.empty( - ( - B_buffer[0].shape[0], - 2 * B_buffer[0].shape[1], - B_buffer[0].shape[2], - ), - dtype=B_buffer[0].dtype, - device=B_buffer[0].device, - ) - self.B_buffer_gate_up[:, : B_buffer[0].shape[1], :].copy_(B_buffer[0]) - self.B_buffer_gate_up[:, B_buffer[0].shape[1] :, :].copy_(B_buffer[1]) - else: - self.B_buffer_gate_up = (B_buffer[0], B_buffer[1]) + self.B_buffer_gate_up = B_buffer def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor: - backend_kwargs = {"base_output": base_output} - lora_output = self.lora_backend.run_gate_up_lora( - x, - self.A_buffer_gate_up, - self.B_buffer_gate_up, - **backend_kwargs, - ) - return ( - lora_output - if self.lora_backend.fuse_output_add - else base_output + lora_output + x=x, + gate_up_lora_a=self.A_buffer_gate_up, + gate_up_lora_b=self.B_buffer_gate_up, + base_output=base_output, ) + return lora_output def slice_lora_a_weights(self, A: torch.Tensor, tp_rank: int): return A @@ -173,9 +145,16 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA): def slice_lora_b_weights(self, B: torch.Tensor, tp_rank: int): # Since the outputs for both gate and up are identical, we use a random one. shard_size = self.base_layer.output_partition_sizes[0] + gate_size = self.base_layer.output_sizes[0] start_idx = tp_rank * shard_size end_idx = (tp_rank + 1) * shard_size - return B[:, start_idx:end_idx, :] + return torch.concat( + ( + B[start_idx:end_idx, :], + B[gate_size + start_idx : gate_size + end_idx], + ), + dim=0, + ) class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA): @@ -185,86 +164,46 @@ class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA): lora_backend: BaseLoRABackend, ) -> None: super().__init__(base_layer, lora_backend) + q_proj_shard_size = self.base_layer.q_proj_shard_size + kv_proj_shard_size = self.base_layer.kv_proj_shard_size + self.output_offset = torch.tensor( + [ + 0, + q_proj_shard_size, + q_proj_shard_size + kv_proj_shard_size, + q_proj_shard_size + 2 * kv_proj_shard_size, + ], + dtype=torch.int32, + device=next(self.base_layer.parameters()).device, + ) + + # For computing number of launched blocks + self.max_qkv_out_dim = max(q_proj_shard_size, kv_proj_shard_size) def set_lora_info( self, A_buffer_qkv: torch.Tensor, - B_buffer_q: torch.Tensor, - B_buffer_kv: torch.Tensor, + B_buffer_qkv: torch.Tensor, ): self.set_lora = True self.A_buffer_qkv = A_buffer_qkv - - if self.lora_backend.fuse_stacked_lora_b: - assert ( - B_buffer_q.shape[-1] == B_buffer_kv.shape[-1] - ), "The lora rank of q and kv should be the same when enabling fusion of qkv lora_b" - output_dim_q, output_dim_kv = B_buffer_q.shape[-2], B_buffer_kv.shape[-2] - - # B_buffer_qkv: (num_lora, output_dim_q + 2 * output_dim_kv, r) - if getattr(self, "B_buffer_qkv", None) is None: - self.B_buffer_qkv = torch.empty( - ( - B_buffer_q[0].shape[0], - output_dim_q + 2 * output_dim_kv, - B_buffer_q[0].shape[2], - ), - dtype=B_buffer_q[0].dtype, - device=B_buffer_q[0].device, - ) - self.B_buffer_qkv[:, :output_dim_q, :].copy_(B_buffer_q[0]) - self.B_buffer_qkv[:, output_dim_q : output_dim_q + output_dim_kv, :].copy_( - B_buffer_kv[0] - ) - self.B_buffer_qkv[:, output_dim_q + output_dim_kv :, :].copy_( - B_buffer_kv[1] - ) - - # Offsets of q/k/v in output dimension - if getattr(self, "output_offset", None) is None: - self.output_offset = torch.tensor( - [ - 0, - output_dim_q, - output_dim_q + output_dim_kv, - output_dim_q + 2 * output_dim_kv, - ], - dtype=torch.int32, - device=B_buffer_q.device, - ) - # For computing number of launched blocks - self.max_qkv_out_dim = max(output_dim_q, output_dim_kv) - else: - self.B_buffer_qkv = ( - B_buffer_q, - B_buffer_kv, - ) + self.B_buffer_qkv = B_buffer_qkv def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor: - backend_kwargs = {"base_output": base_output} - if self.lora_backend.fuse_stacked_lora_b: - backend_kwargs["output_offset"] = self.output_offset - backend_kwargs["max_qkv_out_dim"] = self.max_qkv_out_dim - lora_output = self.lora_backend.run_qkv_lora( - x, - self.A_buffer_qkv, - self.B_buffer_qkv, - **backend_kwargs, - ) - return ( - lora_output - if self.lora_backend.fuse_output_add - else base_output + lora_output + x=x, + qkv_lora_a=self.A_buffer_qkv, + qkv_lora_b=self.B_buffer_qkv, + base_output=base_output, + output_offset=self.output_offset, + max_qkv_out_dim=self.max_qkv_out_dim, ) + return lora_output def slice_lora_a_weights(self, A: torch.Tensor, tp_rank: int): return A - def slice_lora_b_weights( - self, B: List[torch.Tensor], tp_rank: int - ) -> Tuple[torch.Tensor, torch.Tensor]: - B_q, B_kv = B + def slice_lora_b_weights(self, B: torch.Tensor, tp_rank: int) -> torch.Tensor: base_layer = self.base_layer q_proj_shard_size = base_layer.q_proj_shard_size kv_proj_shard_size = base_layer.kv_proj_shard_size @@ -277,7 +216,19 @@ class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA): kv_start_idx = kv_proj_shard_size * kv_shard_id kv_end_idx = kv_start_idx + kv_proj_shard_size - return B_q[q_start_idx:q_end_idx, :], B_kv[:, kv_start_idx:kv_end_idx, :] + q_size, k_size, _ = base_layer.output_sizes + B_q_shard = B[q_start_idx:q_end_idx, :] + B_k_shard = B[q_size + kv_start_idx : q_size + kv_end_idx, :] + B_v_shard = B[q_size + k_size + kv_start_idx : q_size + k_size + kv_end_idx, :] + + return torch.concat( + ( + B_q_shard, + B_k_shard, + B_v_shard, + ), + dim=0, + ) class RowParallelLinearWithLoRA(BaseLayerWithLoRA): @@ -294,18 +245,13 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA): self.B_buffer = B_buffer def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor: - backend_kwargs = {"base_output": base_output} lora_a_output = self.lora_backend.run_lora_a_sgemm(x, self.A_buffer) lora_output = self.lora_backend.run_lora_b_sgemm( - lora_a_output, - self.B_buffer[0], - **backend_kwargs, - ) - return ( - lora_output - if self.lora_backend.fuse_output_add - else base_output + lora_output + x=lora_a_output, + weights=self.B_buffer, + base_output=base_output, ) + return lora_output def forward(self, input_: torch.Tensor): # duplicate the logic in RowParallelLinear diff --git a/python/sglang/srt/lora/lora.py b/python/sglang/srt/lora/lora.py index 7bc6af532..dfd5acda9 100644 --- a/python/sglang/srt/lora/lora.py +++ b/python/sglang/srt/lora/lora.py @@ -117,7 +117,6 @@ class LoRAAdapter(nn.Module): q_name = weight_name k_name = weight_name.replace("q_proj", "k_proj") v_name = weight_name.replace("q_proj", "v_proj") - kv_name = weight_name.replace("q_proj", "kv_proj") qkv_name = weight_name.replace("q_proj", "qkv_proj") # If k_proj doesn't have lora, initialize it to zero @@ -126,57 +125,27 @@ class LoRAAdapter(nn.Module): if "k_proj" in target_module else torch.zeros_like(weights[v_name]) ) - if "lora_A" in weight_name: - weights[qkv_name] = torch.cat( - ( - weights[q_name], - k_proj_weight, - weights[v_name], - ), - 0, - ) - weights.pop(q_name) - if "k_proj" in target_module: - weights.pop(k_name) - weights.pop(v_name) - else: - weights[kv_name] = torch.stack( - [ - k_proj_weight, - weights[v_name], - ], - dim=0, - ) - if "k_proj" in target_module: - weights.pop(k_name) - weights.pop(v_name) + weights[qkv_name] = torch.cat( + ( + weights[q_name], + k_proj_weight, + weights[v_name], + ), + 0, + ) + weights.pop(q_name) + if "k_proj" in target_module: + weights.pop(k_name) + weights.pop(v_name) elif "qkv_proj" in weight_name: # If qkv_proj is already stacked, we normalize it following the SGL convention. qkv_name = weight_name q_name = weight_name.replace("qkv_proj", "q_proj") k_name = weight_name.replace("qkv_proj", "k_proj") v_name = weight_name.replace("qkv_proj", "v_proj") - kv_name = weight_name.replace("qkv_proj", "kv_proj") if "lora_A" in weight_name: weights[qkv_name] = weights[qkv_name].repeat(3, 1) - else: - head_size = ( - self.base_hf_config.hidden_size - // self.base_hf_config.num_attention_heads - ) - weights[q_name], k_proj_weight, v_proj_weight = torch.split( - weights[qkv_name], - [ - head_size * self.base_hf_config.num_attention_heads, - head_size * self.base_hf_config.num_key_value_heads, - head_size * self.base_hf_config.num_key_value_heads, - ], - dim=0, - ) - weights[kv_name] = torch.stack( - [k_proj_weight, v_proj_weight], - dim=0, - ) + # else: no-op as LoRA B weight is already stacked. def normalize_gate_up_proj( self, weight_names: List[str], weights: Dict[str, torch.Tensor] @@ -187,20 +156,14 @@ class LoRAAdapter(nn.Module): gate_up_name = weight_name.replace("gate_proj", "gate_up_proj") if up_name not in weights: weights[up_name] = torch.zeros_like(weights[weight_name]) - # FIXME: Add gate-only support for flashinfer in future implementations assert self.lora_backend.name == "triton", ( f"LoRA weight initialization currently only supported for 'triton' backend. " f"Received backend: {self.lora_backend.name}. Please verify your backend configuration " f"or consider implementing custom initialization logic for other backends." ) - if "lora_A" in weight_name: - weights[gate_up_name] = torch.cat( - (weights[weight_name], weights[up_name]), 0 - ) - else: - weights[gate_up_name] = torch.stack( - [weights[weight_name], weights[up_name]], dim=0 - ) + weights[gate_up_name] = torch.cat( + (weights[weight_name], weights[up_name]), 0 + ) weights.pop(weight_name) if up_name in weights: weights.pop(up_name) @@ -209,12 +172,4 @@ class LoRAAdapter(nn.Module): gate_up_name = weight_name if "lora_A" in weight_name: weights[gate_up_name] = weights[gate_up_name].repeat(2, 1) - else: - output_dim = weights[gate_up_name].shape[0] // 2 - weights[gate_up_name] = torch.stack( - [ - weights[gate_up_name][:output_dim, :], - weights[gate_up_name][output_dim:, :], - ], - dim=0, - ) + # else: no-op as LoRA B weight is already stacked. diff --git a/python/sglang/srt/lora/lora_manager.py b/python/sglang/srt/lora/lora_manager.py index e9fdd0a11..3ab93c73b 100644 --- a/python/sglang/srt/lora/lora_manager.py +++ b/python/sglang/srt/lora/lora_manager.py @@ -31,7 +31,6 @@ from sglang.srt.lora.mem_pool import LoRAMemoryPool from sglang.srt.lora.utils import ( LoRABatchInfo, LoRAType, - get_customized_names_from_hf_names, get_layer_id, get_normalized_lora_weight_names, get_weight_name, @@ -345,40 +344,19 @@ class LoRAManager: ) self.lora_backend.set_batch_info(batch_info) - # TODO (lifuhuang): one potential perf optimization that is worth considering is to see if we can call - # this method only when loading/unloading LoRA adapters, instead of calling it for every micro-batch. - self.update_lora_info() - def update_lora_info(self): """ Update all LoRA modules to associate them with the latest memory buffer. """ for layer_id, layer_modules in enumerate(self.lora_modules): for module_name, module in layer_modules.items(): - if "qkv_proj" in module_name: - module.set_lora_info( - self.memory_pool.get_tensor( - "qkv_proj", layer_id, LoRAType.LORA_A - ), - self.memory_pool.get_tensor( - "q_proj", layer_id, LoRAType.LORA_B - ), - self.memory_pool.get_tensor( - "kv_proj", layer_id, LoRAType.LORA_B - ), - ) - else: - weight_name = get_weight_name( - module_name, self.memory_pool.lora_weight_names, LoRAType.LORA_A - ) - module.set_lora_info( - self.memory_pool.get_tensor( - weight_name, layer_id, LoRAType.LORA_A - ), - self.memory_pool.get_tensor( - weight_name, layer_id, LoRAType.LORA_B - ), - ) + weight_name = get_weight_name( + module_name, self.memory_pool.lora_weight_names + ) + module.set_lora_info( + self.memory_pool.get_tensor(weight_name, layer_id, LoRAType.LORA_A), + self.memory_pool.get_tensor(weight_name, layer_id, LoRAType.LORA_B), + ) def init_state( self, @@ -405,6 +383,7 @@ class LoRAManager: self.init_lora_weight_names() self.init_lora_modules() self.init_memory_pool() + self.update_lora_info() def init_lora_adapters(self, lora_paths: Optional[Dict[str, LoRARef]] = None): # Configs of all active LoRA adapters, indexed by LoRA ID. @@ -461,9 +440,9 @@ class LoRAManager: Add new LoRA weight names if needed based on the current `self.configs`. """ - # Target lora weight names for lora_a and lora_b modules respectively. - lora_A, lora_B = get_normalized_lora_weight_names(self.target_modules) - self.lora_weight_names: Tuple[Set[str]] = (set(lora_A), set(lora_B)) + self.lora_weight_names: Set[str] = get_normalized_lora_weight_names( + self.target_modules + ) def load_lora_weights(self, lora_ref: LoRARef): """ @@ -479,15 +458,6 @@ class LoRAManager: lora_adapter.initialize_weights() self.loras[lora_ref.lora_id] = lora_adapter - # Additional checks for flashinfer backend - # FIXME remove the restrictions after supporting multi-rank for flashinfer backend - if self.lora_backend == "flashinfer": - lora_dims = set(x.r for x in self.configs.values()) - scalings = set(x.scaling for x in self.loras.values()) - assert ( - len(lora_dims) == 1 and len(scalings) == 1 - ), "Flashinfer backend currently only supports single LoRA rank and scaling across all adapters. " - def init_memory_pool(self): """(Re)initialize the LoRA memory pool based on the current configurations.""" self.memory_pool = LoRAMemoryPool( @@ -512,12 +482,6 @@ class LoRAManager: {} for _ in range(self.base_hf_config.num_hidden_layers) ] - # Target module names of customized layers defined in python/sglang/srt/layers - # e.g., {"qkv_proj", "o_proj"} - customized_target_names = get_customized_names_from_hf_names( - self.target_modules, self.base_model - ) - for module_name, module in self.base_model.named_modules(): # TODO (lifuhuang): in the future, we should consider generalizing the # should_apply_lora function to support mapping by full module name instead @@ -530,7 +494,7 @@ class LoRAManager: continue # The module should be converted if it is included in target_names - if module_name.split(".")[-1] in customized_target_names: + if module_name.split(".")[-1] in self.lora_weight_names: layer_id = get_layer_id(module_name) self.lora_modules[layer_id][module_name] = self.set_lora_module( module_name, module diff --git a/python/sglang/srt/lora/mem_pool.py b/python/sglang/srt/lora/mem_pool.py index cc00c7212..56cd39d67 100644 --- a/python/sglang/srt/lora/mem_pool.py +++ b/python/sglang/srt/lora/mem_pool.py @@ -52,7 +52,7 @@ class LoRAMemoryPool: tp_size: int, tp_rank: int, max_lora_rank: int, - lora_weight_names: Tuple[Set[str], Set[str]], + lora_weight_names: Set[str], base_model: torch.nn.Module, ): self.base_hf_config: AutoConfig = base_hf_config @@ -62,9 +62,7 @@ class LoRAMemoryPool: self.tp_size: int = tp_size self.tp_rank: int = tp_rank self.max_lora_rank: int = max_lora_rank - - # lora weight names for LoRA A and B respectively. - self.lora_weight_names: Tuple[Set[str], Set[str]] = lora_weight_names + self.lora_weight_names: Set[str] = lora_weight_names # Both A_buffer and B_buffer maps lora weight names to its buffer space. # A_buffer contains num_layer number of row-major tensors with shape @@ -97,12 +95,8 @@ class LoRAMemoryPool: """ if config.r > self.max_lora_rank: return False - weights_a, weights_b = get_normalized_lora_weight_names( - config.target_modules - ) - return weights_a.issubset(self.lora_weight_names[0]) and weights_b.issubset( - self.lora_weight_names[1] - ) + weights = get_normalized_lora_weight_names(config.target_modules) + return weights.issubset(self.lora_weight_names) if isinstance(config, LoRAConfig): return _can_support(config) @@ -132,11 +126,9 @@ class LoRAMemoryPool: Given a module_name (might be a stacked name), return the hidden dims of modules' input and output. """ _, output_dim = get_hidden_dim(module_name, self.base_hf_config, base_model) - c = get_stacked_multiply(module_name) if self.tp_size > 1 and module_name not in ROW_PARALLELISM_LINEAR_LORA_NAMES: output_dim = divide(output_dim, self.tp_size) return ( - c, self.max_loras_per_batch, output_dim, max_lora_dim, @@ -165,13 +157,13 @@ class LoRAMemoryPool: init_buffer( self.A_buffer, - self.lora_weight_names[0], + self.lora_weight_names, self.get_lora_A_shape, ) init_buffer( self.B_buffer, - self.lora_weight_names[1], + self.lora_weight_names, self.get_lora_B_shape, ) @@ -246,7 +238,7 @@ class LoRAMemoryPool: return assert lora_adapter is not None - lora_rank = lora_adapter.config.hf_config["r"] + lora_rank = lora_adapter.config.r for layer_id in range(self.num_layer): layer_weights = lora_adapter.layers[layer_id].weights temp_A_buffer: Dict[str, Optional[torch.Tensor]] = { @@ -256,73 +248,38 @@ class LoRAMemoryPool: weight_name: None for weight_name in self.B_buffer } for name, weights in layer_weights.items(): + lora_weight_name = get_weight_name(name, self.lora_weight_names) if "lora_A" in name: - lora_weight_name = get_weight_name( - name, self.lora_weight_names, LoRAType.LORA_A - ) temp_A_buffer[lora_weight_name] = weights else: - lora_weight_name = get_weight_name( - name, self.lora_weight_names, LoRAType.LORA_B - ) temp_B_buffer[lora_weight_name] = weights if self.tp_size > 1: cur_layer_modules = lora_modules[layer_id] for module_name, module in cur_layer_modules.items(): - weight_name = get_weight_name( - module_name, self.lora_weight_names, LoRAType.LORA_A - ) + weight_name = get_weight_name(module_name, self.lora_weight_names) if temp_A_buffer[weight_name] is None: # Skip weight slicing if the weight is not present in the adapter continue - if "qkv_proj" in module_name: - temp_A_buffer["qkv_proj"] = module.slice_lora_a_weights( - temp_A_buffer["qkv_proj"], self.tp_rank - ) - temp_B_buffer["q_proj"], temp_B_buffer["kv_proj"] = ( - module.slice_lora_b_weights( - [temp_B_buffer["q_proj"], temp_B_buffer["kv_proj"]], - self.tp_rank, - ) - ) - else: - # TODO (lifuhuang): Ideally, we should call `get_weight_name` separately for both A and B. - # Currently, we're reusing A's weight name as a workaround, relying on the fact that A and - # B share the same name except for `qkv_proj`. We should clean this up once we deprecate the - # FlashInfer LoRA backend. - temp_A_buffer[weight_name] = module.slice_lora_a_weights( - temp_A_buffer[weight_name], self.tp_rank - ) - temp_B_buffer[weight_name] = module.slice_lora_b_weights( - temp_B_buffer[weight_name], self.tp_rank - ) + temp_A_buffer[weight_name] = module.slice_lora_a_weights( + temp_A_buffer[weight_name], self.tp_rank + ) + temp_B_buffer[weight_name] = module.slice_lora_b_weights( + temp_B_buffer[weight_name], self.tp_rank + ) for name, weights in temp_A_buffer.items(): c = get_stacked_multiply(name) - buffer_view = self.A_buffer[name][layer_id][buffer_id][ - : lora_rank * c, : - ] + target_buffer = self.A_buffer[name][layer_id] + buffer_view = target_buffer[buffer_id, : lora_rank * c, :] load_lora_weight_tensor(buffer_view, weights) for name, weights in temp_B_buffer.items(): - c = get_stacked_multiply(name) - if c > 1: - for stacked_id in range(c): - buffer_view = self.B_buffer[name][layer_id][stacked_id][ - buffer_id - ][:, :lora_rank] - weight_slice = ( - weights[stacked_id] if weights is not None else None - ) - load_lora_weight_tensor(buffer_view, weight_slice) - else: - buffer_view = self.B_buffer[name][layer_id][0][buffer_id][ - :, :lora_rank - ] - load_lora_weight_tensor(buffer_view, weights) + target_buffer = self.B_buffer[name][layer_id] + buffer_view = target_buffer[buffer_id, :, :lora_rank] + load_lora_weight_tensor(buffer_view, weights) def get_tensor( self, weight_name: str, layer_id: int, lora_type: LoRAType diff --git a/python/sglang/srt/lora/triton_ops/qkv_lora_b.py b/python/sglang/srt/lora/triton_ops/qkv_lora_b.py index e685e526b..1d6663dbe 100644 --- a/python/sglang/srt/lora/triton_ops/qkv_lora_b.py +++ b/python/sglang/srt/lora/triton_ops/qkv_lora_b.py @@ -119,7 +119,7 @@ def _qkv_lora_b_kernel( output_ptr = (output + seg_start * output_stride_0 + n_start * output_stride_1) + ( s_offset[:, None] * output_stride_0 + n_offset[None, :] * output_stride_1 ) - output_mask = (s_offset[:, None] < seg_len) and (n_offset[None, :] < n_size) + output_mask = (s_offset[:, None] < seg_len) & (n_offset[None, :] < n_size) partial_sum += tl.load(output_ptr, mask=output_mask) tl.store(output_ptr, partial_sum, mask=output_mask) diff --git a/python/sglang/srt/lora/utils.py b/python/sglang/srt/lora/utils.py index 61642cba5..e5aa43eff 100644 --- a/python/sglang/srt/lora/utils.py +++ b/python/sglang/srt/lora/utils.py @@ -47,34 +47,6 @@ def get_layer_id(name: str) -> int: return int(match.group(1)) -def get_customized_names_from_hf_names( - hf_module_names: Set[str], base_model: torch.nn.Module -) -> Set[str]: - """ - This function takes in a set of huggingface style module names: - e.g., {"k_proj", "q_proj", "v_proj", "o_proj"} - and outputs a set of module names of customized sglang layers: - e.g., {"qkv_proj", "o_proj"} - """ - if hasattr(base_model, "get_module_name"): - return {base_model.get_module_name(name) for name in hf_module_names} - else: - """ - Fallback solution of mapping from config module name to module name in model class. - Please check if it aligns with your base model. - Please implement the function in the model class if it is not. - You can reference this function in llama.py. - """ - params_mapping = { - "q_proj": "qkv_proj", - "k_proj": "qkv_proj", - "v_proj": "qkv_proj", - "gate_proj": "gate_up_proj", - "up_proj": "gate_up_proj", - } - return {params_mapping.get(name, name) for name in hf_module_names} - - def get_hidden_dim( module_name: str, config: AutoConfig, base_model: torch.nn.Module ) -> Tuple[int]: @@ -95,22 +67,9 @@ def get_hidden_dim( head_dim = getattr( config, "head_dim", config.hidden_size // config.num_attention_heads ) - - # TODO: the special handling of qkv will be addressed in #8940. if module_name == "qkv_proj": - return ( - config.hidden_size, - None, # qkv_proj is only used in LoRA A - ) - elif module_name == "kv_proj": - return ( - None, # kv_proj is only used in LoRA B - head_dim * config.num_key_value_heads, - ) - elif module_name == "q_proj": - return ( - None, # q_proj is only used in LoRA B - head_dim * config.num_attention_heads, + return config.hidden_size, head_dim * ( + config.num_attention_heads + config.num_key_value_heads * 2 ) elif module_name == "o_proj": return ( @@ -118,7 +77,7 @@ def get_hidden_dim( config.hidden_size, ) elif module_name == "gate_up_proj": - return config.hidden_size, config.intermediate_size + return config.hidden_size, config.intermediate_size * 2 elif module_name == "down_proj": return config.intermediate_size, config.hidden_size else: @@ -127,26 +86,22 @@ def get_hidden_dim( def get_normalized_lora_weight_names( target_modules: Iterable[str], -) -> Tuple[set[str], set[str]]: +) -> set[str]: """ Mapping a list of target module name to names of the normalized LoRA weights. - Returned tuple contains (name for Lora A, name for Lora B) """ params_mapping = { - "q_proj": (["qkv_proj"], ["q_proj"]), - "k_proj": (["qkv_proj"], ["kv_proj"]), - "v_proj": (["qkv_proj"], ["kv_proj"]), - "gate_proj": (["gate_up_proj"], ["gate_up_proj"]), - "up_proj": (["gate_up_proj"], ["gate_up_proj"]), - "qkv_proj": (["qkv_proj"], ["q_proj", "kv_proj"]), - "gate_up_proj": (["gate_up_proj"], ["gate_up_proj"]), + "q_proj": "qkv_proj", + "k_proj": "qkv_proj", + "v_proj": "qkv_proj", + "gate_proj": "gate_up_proj", + "up_proj": "gate_up_proj", } - result = (set(), set()) + result = set() for name in target_modules: - lora_a, lora_b = params_mapping.get(name, ([name], [name])) - result[0].update(lora_a) - result[1].update(lora_b) + weight_name = params_mapping.get(name, name) + result.add(weight_name) return result @@ -156,23 +111,21 @@ def get_stacked_multiply(module_name: str) -> int: """ stacked_rank = { "qkv_proj": 3, - "kv_proj": 2, "gate_up_proj": 2, } return stacked_rank[module_name] if module_name in stacked_rank else 1 def get_weight_name( - target_name: str, lora_weight_names: Tuple[Set[str]], lora_type: LoRAType + target_name: str, lora_weight_names: Tuple[Set[str]] ) -> Optional[str]: """ - target_name is name of a given module, - lora_weight_names is a set of lora stacked name pairs (see get_stacked_name method above) + Get the weight name in lora_weight_names that can match target_name. + If there is a weight name in lora_weight_names that can match target_name, return this name Else raise ValueError. """ - idx = 0 if lora_type == LoRAType.LORA_A else 1 - for weight_name in lora_weight_names[idx]: + for weight_name in lora_weight_names: if weight_name in target_name: return weight_name raise ValueError( @@ -180,9 +133,4 @@ def get_weight_name( ) -# TODO: [PR #4274] For future use to simplify the mapping between HF module names and customized module names. -VOCAB_PARALLELISM_EMBEDDING_NAMES = ["embeddings"] -COLUMN_PARALLELISM_LINEAR_LORA_NAMES = ["gate_proj", "up_proj"] -MERGED_COLUMN_PARALLELISM_LINEAR_LORA_NAMES = ["gate_up_proj"] -QKV_PARALLELISM_LINEAR_LORA_NAMES = ["qkv_proj"] ROW_PARALLELISM_LINEAR_LORA_NAMES = ["o_proj", "down_proj"] diff --git a/python/sglang/srt/models/gemma3n_mm.py b/python/sglang/srt/models/gemma3n_mm.py index f9c58eaae..fa9a10c85 100644 --- a/python/sglang/srt/models/gemma3n_mm.py +++ b/python/sglang/srt/models/gemma3n_mm.py @@ -501,23 +501,16 @@ class Gemma3nForConditionalGeneration(PreTrainedModel): def get_hidden_dim(self, module_name): # return input_dim, output_dim - # TODO: the special handling of qkv will be addressed in #8940. if module_name == "qkv_proj": return ( self.config.hidden_size, - None, # qkv_proj is only used in LoRA A + self.config.head_dim + * ( + self.config.num_attention_heads + + self.config.num_key_value_heads * 2 + ), ) - elif module_name == "kv_proj": - return ( - None, # kv_proj is only used in LoRA B - self.config.head_dim * self.config.num_key_value_heads, - ) - elif module_name == "q_proj": - return ( - None, # q_proj is only used in LoRA B - self.config.head_dim * self.config.num_attention_heads, - ) - elif module_name in ["o_proj"]: + elif module_name == "o_proj": return ( self.config.head_dim * self.config.num_attention_heads, self.config.hidden_size, @@ -527,7 +520,7 @@ class Gemma3nForConditionalGeneration(PreTrainedModel): "Currently SGLang requires uniform intermediate size for all layers. " "Please file an issue if you need support for non-uniform intermediate sizes." ) - return self.config.hidden_size, self.config.intermediate_size[0] + return self.config.hidden_size, self.config.intermediate_size[0] * 2 elif module_name == "down_proj": assert len(set(self.config.intermediate_size)) == 1, ( "Currently SGLang requires uniform intermediate size for all layers. "